Li-Ruixiao commited on
Commit
8b576d1
·
1 Parent(s): 7db4b79

init model

Browse files
README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MossAudioTokenizer (remote code)
2
+
3
+ MossAudioTokenizer is a neural audio codec model for audio tokenization and synthesis.
4
+
5
+ This repository contains a lightweight “remote code” implementation that mirrors the current 🤗 Transformers
6
+ `transformers.models.moss_audio_tokenizer` module. It is intended to be uploaded to a Hugging Face Hub model repository
7
+ and loaded with `trust_remote_code=True` when needed.
8
+
9
+ ## Quickstart
10
+
11
+ ```python
12
+ import torch
13
+ from transformers import AutoModel
14
+
15
+ repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
16
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
17
+
18
+ audio = torch.randn(1, 1, 3200) # dummy waveform
19
+ enc = model.encode(audio, return_dict=True)
20
+ dec = model.decode(enc.audio_codes, return_dict=True)
21
+ ```
22
+
23
+ ## Streaming
24
+
25
+ `MossAudioTokenizerModel.encode` and `MossAudioTokenizerModel.decode` support simple streaming via a `chunk_duration`
26
+ argument.
27
+
28
+ - `chunk_duration` is expressed in seconds.
29
+ - It must be <= `MossAudioTokenizerConfig.causal_transformer_context_duration`.
30
+ - `chunk_duration * MossAudioTokenizerConfig.sampling_rate` must be divisible by `MossAudioTokenizerConfig.downsample_rate`.
31
+ - Current limitation: streaming chunking only supports `batch_size=1`.
32
+
33
+ ```python
34
+ import torch
35
+ from transformers import AutoModel
36
+
37
+ repo_id = "<org-or-user>/<model-repo>"
38
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
39
+ audio = torch.randn(1, 1, 3200) # dummy waveform
40
+
41
+ # 0.08s @ 24kHz = 1920 samples, divisible by downsample_rate=1920
42
+ enc = model.encode(audio, return_dict=True, chunk_duration=0.08)
43
+ dec = model.decode(enc.audio_codes, return_dict=True, chunk_duration=0.08)
44
+ ```
45
+
46
+ ## Repository layout
47
+
48
+ Remote-code modules:
49
+ - `configuration_moss_audio_tokenizer.py`
50
+ - `modeling_moss_audio_tokenizer.py`
51
+ - `__init__.py`
52
+
53
+ Hub model files:
54
+ - `config.json`
55
+ - model weights
56
+
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Remote code package for Moss audio tokenizer."""
config.json ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MossAudioTokenizerModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": [
7
+ null,
8
+ "configuration_moss_audio_tokenizer.MossAudioTokenizerConfig"
9
+ ],
10
+ "AutoModel": [
11
+ null,
12
+ "modeling_moss_audio_tokenizer.MossAudioTokenizerModel"
13
+ ]
14
+ },
15
+ "causal_transformer_context_duration": 10,
16
+ "code_dim": 768,
17
+ "decoder_kwargs": [
18
+ {
19
+ "causal": true,
20
+ "conv_layout": true,
21
+ "d_model": 1280,
22
+ "dim_feedforward": 5120,
23
+ "gating": "none",
24
+ "input_dimension": 768,
25
+ "layer_scale": 0.01,
26
+ "max_period": 10000,
27
+ "module_type": "Transformer",
28
+ "norm": "layer_norm",
29
+ "num_heads": 20,
30
+ "num_layers": 32,
31
+ "output_dimension": 1280,
32
+ "positional_embedding": "rope"
33
+ },
34
+ {
35
+ "module_type": "PatchedPretransform",
36
+ "patch_size": 2
37
+ },
38
+ {
39
+ "causal": true,
40
+ "conv_layout": true,
41
+ "d_model": 768,
42
+ "dim_feedforward": 3072,
43
+ "gating": "none",
44
+ "input_dimension": 640,
45
+ "layer_scale": 0.01,
46
+ "max_period": 10000,
47
+ "module_type": "Transformer",
48
+ "norm": "layer_norm",
49
+ "num_heads": 12,
50
+ "num_layers": 12,
51
+ "output_dimension": 768,
52
+ "positional_embedding": "rope"
53
+ },
54
+ {
55
+ "module_type": "PatchedPretransform",
56
+ "patch_size": 2
57
+ },
58
+ {
59
+ "causal": true,
60
+ "conv_layout": true,
61
+ "d_model": 768,
62
+ "dim_feedforward": 3072,
63
+ "gating": "none",
64
+ "input_dimension": 384,
65
+ "layer_scale": 0.01,
66
+ "max_period": 10000,
67
+ "module_type": "Transformer",
68
+ "norm": "layer_norm",
69
+ "num_heads": 12,
70
+ "num_layers": 12,
71
+ "output_dimension": 768,
72
+ "positional_embedding": "rope"
73
+ },
74
+ {
75
+ "module_type": "PatchedPretransform",
76
+ "patch_size": 2
77
+ },
78
+ {
79
+ "causal": true,
80
+ "conv_layout": true,
81
+ "d_model": 768,
82
+ "dim_feedforward": 3072,
83
+ "gating": "none",
84
+ "input_dimension": 384,
85
+ "layer_scale": 0.01,
86
+ "max_period": 10000,
87
+ "module_type": "Transformer",
88
+ "norm": "layer_norm",
89
+ "num_heads": 12,
90
+ "num_layers": 12,
91
+ "output_dimension": 240,
92
+ "positional_embedding": "rope"
93
+ },
94
+ {
95
+ "module_type": "PatchedPretransform",
96
+ "patch_size": 240
97
+ }
98
+ ],
99
+ "downsample_rate": 1920,
100
+ "dtype": "float32",
101
+ "encoder_kwargs": [
102
+ {
103
+ "module_type": "PatchedPretransform",
104
+ "patch_size": 240
105
+ },
106
+ {
107
+ "causal": true,
108
+ "conv_layout": true,
109
+ "d_model": 768,
110
+ "dim_feedforward": 3072,
111
+ "gating": "none",
112
+ "input_dimension": 240,
113
+ "layer_scale": 0.01,
114
+ "max_period": 10000,
115
+ "module_type": "Transformer",
116
+ "norm": "layer_norm",
117
+ "num_heads": 12,
118
+ "num_layers": 12,
119
+ "output_dimension": 384,
120
+ "positional_embedding": "rope"
121
+ },
122
+ {
123
+ "module_type": "PatchedPretransform",
124
+ "patch_size": 2
125
+ },
126
+ {
127
+ "causal": true,
128
+ "conv_layout": true,
129
+ "d_model": 768,
130
+ "dim_feedforward": 3072,
131
+ "gating": "none",
132
+ "input_dimension": 768,
133
+ "layer_scale": 0.01,
134
+ "max_period": 10000,
135
+ "module_type": "Transformer",
136
+ "norm": "layer_norm",
137
+ "num_heads": 12,
138
+ "num_layers": 12,
139
+ "output_dimension": 384,
140
+ "positional_embedding": "rope"
141
+ },
142
+ {
143
+ "module_type": "PatchedPretransform",
144
+ "patch_size": 2
145
+ },
146
+ {
147
+ "causal": true,
148
+ "conv_layout": true,
149
+ "d_model": 768,
150
+ "dim_feedforward": 3072,
151
+ "gating": "none",
152
+ "input_dimension": 768,
153
+ "layer_scale": 0.01,
154
+ "max_period": 10000,
155
+ "module_type": "Transformer",
156
+ "norm": "layer_norm",
157
+ "num_heads": 12,
158
+ "num_layers": 12,
159
+ "output_dimension": 640,
160
+ "positional_embedding": "rope"
161
+ },
162
+ {
163
+ "module_type": "PatchedPretransform",
164
+ "patch_size": 2
165
+ },
166
+ {
167
+ "causal": true,
168
+ "conv_layout": true,
169
+ "d_model": 1280,
170
+ "dim_feedforward": 5120,
171
+ "gating": "none",
172
+ "input_dimension": 1280,
173
+ "layer_scale": 0.01,
174
+ "max_period": 10000,
175
+ "module_type": "Transformer",
176
+ "norm": "layer_norm",
177
+ "num_heads": 20,
178
+ "num_layers": 32,
179
+ "output_dimension": 768,
180
+ "positional_embedding": "rope"
181
+ }
182
+ ],
183
+ "model_type": "speech_tokenizer",
184
+ "quantizer_kwargs": {
185
+ "codebook_dim": 8,
186
+ "codebook_size": 1024,
187
+ "input_dim": 768,
188
+ "num_quantizers": 32,
189
+ "output_dim": 768,
190
+ "quantizer_type": "rlfq",
191
+ "rvq_dim": 512
192
+ },
193
+ "quantizer_type": "rlfq",
194
+ "reversed_decoder_kwargs": [
195
+ {
196
+ "module_type": "PatchedPretransform",
197
+ "patch_size": 240
198
+ },
199
+ {
200
+ "causal": true,
201
+ "conv_layout": true,
202
+ "d_model": 768,
203
+ "dim_feedforward": 3072,
204
+ "gating": "none",
205
+ "input_dimension": 240,
206
+ "layer_scale": 0.01,
207
+ "max_period": 10000,
208
+ "module_type": "Transformer",
209
+ "norm": "layer_norm",
210
+ "num_heads": 12,
211
+ "num_layers": 12,
212
+ "output_dimension": 384,
213
+ "positional_embedding": "rope"
214
+ },
215
+ {
216
+ "module_type": "PatchedPretransform",
217
+ "patch_size": 2
218
+ },
219
+ {
220
+ "causal": true,
221
+ "conv_layout": true,
222
+ "d_model": 768,
223
+ "dim_feedforward": 3072,
224
+ "gating": "none",
225
+ "input_dimension": 768,
226
+ "layer_scale": 0.01,
227
+ "max_period": 10000,
228
+ "module_type": "Transformer",
229
+ "norm": "layer_norm",
230
+ "num_heads": 12,
231
+ "num_layers": 12,
232
+ "output_dimension": 384,
233
+ "positional_embedding": "rope"
234
+ },
235
+ {
236
+ "module_type": "PatchedPretransform",
237
+ "patch_size": 2
238
+ },
239
+ {
240
+ "causal": true,
241
+ "conv_layout": true,
242
+ "d_model": 768,
243
+ "dim_feedforward": 3072,
244
+ "gating": "none",
245
+ "input_dimension": 768,
246
+ "layer_scale": 0.01,
247
+ "max_period": 10000,
248
+ "module_type": "Transformer",
249
+ "norm": "layer_norm",
250
+ "num_heads": 12,
251
+ "num_layers": 12,
252
+ "output_dimension": 640,
253
+ "positional_embedding": "rope"
254
+ },
255
+ {
256
+ "module_type": "PatchedPretransform",
257
+ "patch_size": 2
258
+ },
259
+ {
260
+ "causal": true,
261
+ "conv_layout": true,
262
+ "d_model": 1280,
263
+ "dim_feedforward": 5120,
264
+ "gating": "none",
265
+ "input_dimension": 1280,
266
+ "layer_scale": 0.01,
267
+ "max_period": 10000,
268
+ "module_type": "Transformer",
269
+ "norm": "layer_norm",
270
+ "num_heads": 20,
271
+ "num_layers": 32,
272
+ "output_dimension": 768,
273
+ "positional_embedding": "rope"
274
+ }
275
+ ],
276
+ "sample_rate": 24000,
277
+ "sampling_rate": 24000,
278
+ "transformers_version": "4.56.0.dev0",
279
+ "version": "4.26.1.a"
280
+ }
configuration_moss_audio_tokenizer.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MossAudioTokenizer model configuration"""
16
+
17
+ from typing import Any
18
+
19
+ from transformers.configuration_utils import PreTrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class MossAudioTokenizerConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MossAudioTokenizerModel`]. It is used to instantiate a
    MossAudioTokenizer model according to the specified arguments, defining the model architecture.

    Instantiating a configuration with the defaults will yield a similar configuration to that of the
    [VoiceAgentGroup/moss_audio_tokenizer](https://huggingface.co/VoiceAgentGroup/moss_audio_tokenizer) architecture.

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        version (`str`, *optional*):
            Checkpoint version string. Accepted for compatibility; not used in modeling.
        sampling_rate (`int`, *optional*, defaults to 24000):
            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
        downsample_rate (`int`, *optional*, defaults to 1920):
            Total downsampling rate from waveform to tokens.
        causal_transformer_context_duration (`float`, *optional*, defaults to 10.0):
            Context duration in seconds for causal transformer.
        encoder_kwargs (`list[dict]`, *optional*):
            List of encoder module configurations. Each dict specifies a module type and its parameters.
        decoder_kwargs (`list[dict]`, *optional*):
            List of decoder module configurations in execution order.
        quantizer_type (`str`, *optional*, defaults to `"rlfq"`):
            Quantizer type. Options include `"rvq"`, `"spec_rvq"`, `"rlfq"`, `"random_prefix_rlfq"`.
            Ignored when `quantizer_kwargs` already contains a `"quantizer_type"` entry.
        quantizer_kwargs (`dict`, *optional*):
            Configuration for the quantizer including `input_dim`, `rvq_dim`, `output_dim`, `num_quantizers`,
            `codebook_size`, and `codebook_dim`.

    Example:

    ```python
    >>> from transformers import MossAudioTokenizerModel, MossAudioTokenizerConfig

    >>> # Initializing a MossAudioTokenizer style configuration
    >>> configuration = MossAudioTokenizerConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = MossAudioTokenizerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "moss-audio-tokenizer"

    # Backward-compatible alias used by some checkpoints.
    attribute_map = {"sample_rate": "sampling_rate"}

    sampling_rate: int
    downsample_rate: int
    causal_transformer_context_duration: float
    encoder_kwargs: list[dict[str, Any]]
    decoder_kwargs: list[dict[str, Any]]
    quantizer_type: str
    quantizer_kwargs: dict[str, Any]

    def __init__(
        self,
        version: str | None = None,
        sampling_rate: int = 24000,
        downsample_rate: int = 1920,
        causal_transformer_context_duration: float = 10.0,
        encoder_kwargs: list[dict[str, Any]] | None = None,
        decoder_kwargs: list[dict[str, Any]] | None = None,
        quantizer_type: str = "rlfq",
        quantizer_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        # Some checkpoints might include an incorrect/legacy `model_type` (e.g. "speech_tokenizer").
        # We drop it to avoid overriding the class-level `model_type`.
        kwargs.pop("model_type", None)

        # `version` is accepted for compatibility but not used in modeling.
        self.version = version
        self.sampling_rate = sampling_rate
        self.downsample_rate = downsample_rate
        self.causal_transformer_context_duration = causal_transformer_context_duration

        def _transformer(
            input_dim: int, output_dim: int, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int
        ) -> dict[str, Any]:
            # All default transformer stages share the same remaining hyper-parameters.
            return {
                "module_type": "Transformer",
                "input_dimension": input_dim,
                "output_dimension": output_dim,
                "d_model": d_model,
                "num_heads": num_heads,
                "num_layers": num_layers,
                "dim_feedforward": dim_feedforward,
                "causal": True,
                "norm": "layer_norm",
                "positional_embedding": "rope",
                "max_period": 10000,
                "gating": "none",
                "layer_scale": 0.01,
                "conv_layout": True,
            }

        def _patch(patch_size: int) -> dict[str, Any]:
            return {"module_type": "PatchedPretransform", "patch_size": patch_size}

        # Default encoder configuration
        if encoder_kwargs is None:
            encoder_kwargs = [
                _patch(240),
                _transformer(240, 384, 768, 12, 12, 3072),
                _patch(2),
                _transformer(768, 384, 768, 12, 12, 3072),
                _patch(2),
                _transformer(768, 640, 768, 12, 12, 3072),
                _patch(2),
                _transformer(1280, 768, 1280, 20, 32, 5120),
            ]
        self.encoder_kwargs = encoder_kwargs

        # Default decoder configuration (execution order)
        if decoder_kwargs is None:
            decoder_kwargs = [
                _transformer(768, 1280, 1280, 20, 32, 5120),
                _patch(2),
                _transformer(640, 768, 768, 12, 12, 3072),
                _patch(2),
                _transformer(384, 768, 768, 12, 12, 3072),
                _patch(2),
                _transformer(384, 768, 768, 12, 12, 3072),
                _patch(2),
                _transformer(384, 240, 768, 12, 12, 3072),
                _patch(240),
            ]
        self.decoder_kwargs = decoder_kwargs

        # Default quantizer configuration
        if quantizer_kwargs is None:
            quantizer_kwargs = {
                "input_dim": 768,
                "rvq_dim": 512,
                "output_dim": 768,
                "num_quantizers": 32,
                "codebook_size": 1024,
                "codebook_dim": 8,
                "quantizer_type": "rlfq",
            }
        else:
            # Work on a shallow copy so the caller's dict is never mutated below.
            quantizer_kwargs = dict(quantizer_kwargs)

        # A `quantizer_type` embedded in `quantizer_kwargs` (checkpoint format) takes
        # precedence over the top-level argument; otherwise propagate the argument.
        kw_qtype = quantizer_kwargs.get("quantizer_type")
        if kw_qtype is not None:
            self.quantizer_type = kw_qtype
        else:
            self.quantizer_type = quantizer_type
            quantizer_kwargs["quantizer_type"] = quantizer_type

        self.quantizer_kwargs = quantizer_kwargs

        super().__init__(**kwargs)

    @property
    def num_quantizers(self) -> int:
        """Return the number of quantizers from quantizer_kwargs."""
        return self.quantizer_kwargs.get("num_quantizers", 32)

    @property
    def codebook_size(self) -> int:
        """Return the codebook size from quantizer_kwargs.

        NOTE(review): the fallback (4096) differs from the default
        `quantizer_kwargs` value (1024); it only applies when the key is
        missing from a loaded configuration.
        """
        return self.quantizer_kwargs.get("codebook_size", 4096)

    @property
    def frame_rate(self) -> float:
        """Return the frame rate (tokens per second): `sampling_rate / downsample_rate`."""
        return self.sampling_rate / self.downsample_rate


__all__ = ["MossAudioTokenizerConfig"]
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:037f441ed30a0ab59f6049de83b824a1b3bd6feb7dbd46c3fbca41fc2f649f28
3
+ size 4998259168
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a187d73d2cda1c2d0676586d9d03c09c0a5813450266af32029c871493fc9582
3
+ size 2100202560
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_moss_audio_tokenizer.py ADDED
@@ -0,0 +1,1812 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MossAudioTokenizer model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import copy
20
+ import math
21
+ from contextlib import ExitStack, contextmanager
22
+ from dataclasses import dataclass
23
+ from typing import cast
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+ from transformers.modeling_utils import PreTrainedAudioTokenizerBase
30
+ from transformers.utils import ModelOutput, auto_docstring, logging
31
+
32
+ from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ # =============================================================================
39
+ # Output Classes
40
+ # =============================================================================
41
+
42
+
43
+ @dataclass
44
+ @auto_docstring
45
+ class MossAudioTokenizerEncoderOutput(ModelOutput):
46
+ r"""
47
+ audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
48
+ Discrete audio codes computed using the encoder and quantizer.
49
+ audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
50
+ Valid lengths for each sample's audio codes.
51
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_size, sequence_length)`, *optional*):
52
+ Hidden states from the encoder before quantization.
53
+ """
54
+
55
+ audio_codes: torch.Tensor | None = None
56
+ audio_codes_lengths: torch.Tensor | None = None
57
+ encoder_hidden_states: torch.Tensor | None = None
58
+
59
+
60
+ @dataclass
61
+ @auto_docstring
62
+ class MossAudioTokenizerDecoderOutput(ModelOutput):
63
+ r"""
64
+ audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
65
+ Decoded audio waveform.
66
+ audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
67
+ Valid lengths for each sample's audio.
68
+ """
69
+
70
+ audio: torch.Tensor | None = None
71
+ audio_lengths: torch.Tensor | None = None
72
+
73
+
74
+ @dataclass
75
+ @auto_docstring
76
+ class MossAudioTokenizerOutput(ModelOutput):
77
+ r"""
78
+ audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
79
+ Decoded audio waveform.
80
+ audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
81
+ Valid lengths for each sample's audio.
82
+ audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
83
+ Discrete audio codes computed using the encoder and quantizer.
84
+ audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
85
+ Valid lengths for each sample's audio codes.
86
+ """
87
+
88
+ audio: torch.Tensor | None = None
89
+ audio_lengths: torch.Tensor | None = None
90
+ audio_codes: torch.Tensor | None = None
91
+ audio_codes_lengths: torch.Tensor | None = None
92
+
93
+
94
+ # =============================================================================
95
+ # Streaming Module Base Classes
96
+ # =============================================================================
97
+
98
+
99
@dataclass
class StreamingState:
    """Per-session state carried by a streaming module.

    `exec_mask` is a boolean vector of length `batch_size` flagging which batch
    entries are currently active. Instances also act as (no-op) context managers
    so they can be registered on an `ExitStack` while streaming.
    """

    batch_size: int
    device: torch.device

    def __post_init__(self):
        # All batch entries start out active.
        self.exec_mask = torch.full((self.batch_size,), True, dtype=torch.bool, device=self.device)

    def set_exec_mask(self, exec_mask: torch.Tensor):
        # Overwrite in place so existing references to the tensor stay valid.
        self.exec_mask.copy_(exec_mask)

    def reset(self, reset_mask: torch.Tensor) -> None:
        # Entries flagged by `reset_mask` become active again; the rest keep their value.
        self.exec_mask.logical_or_(reset_mask)

    def __enter__(self):
        # ExitStack expects a context manager; returning self is conventional and useful for debugging.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        pass
121
+
122
+
123
class StreamingModule(nn.Module):
    """Base class for streaming components.

    While a `streaming(...)` context is active, each streaming module in the
    tree holds a `StreamingState` in `_streaming_state`; outside the context it
    is `None`.
    """

    def __init__(self) -> None:
        super().__init__()
        # Active per-session state; `None` means "not currently streaming".
        self._streaming_state: StreamingState | None = None
        # When True, a parent's traversal skips this module (it is expected to
        # manage its own streaming lifecycle).
        self._streaming_detached: bool = False
        # Lazily built cache of (qualified name, module) pairs; populated on the
        # first `_apply_named_streaming` call and reused afterwards.
        self._cached_children: list[tuple[str, StreamingModule]] | None = None

    @property
    def is_streaming(self):
        # True while a streaming context is active on this module.
        return self._streaming_state is not None

    def _apply_named_streaming(self, fn):
        """Apply `fn(name, module)` to every reachable StreamingModule in the tree."""

        def _handle_module(prefix: str, module: nn.Module):
            if isinstance(module, StreamingModule):
                # Detached submodules are skipped, including their subtree. The
                # root itself (prefix == "") is never skipped.
                if module._streaming_detached and prefix != "":
                    return
                if self._cached_children is None:
                    raise RuntimeError("Internal error: _cached_children should be initialized before traversal.")
                self._cached_children.append((prefix, module))
            # Recurse through all children (including plain nn.Module containers).
            for name, child in module.named_children():
                new_prefix = f"{prefix}.{name}" if prefix else name
                _handle_module(new_prefix, child)

        # Traverse only on the first call; later calls reuse the cached list.
        # NOTE(review): the cache assumes the module tree does not change after
        # the first traversal — confirm if submodules are added dynamically.
        if self._cached_children is None:
            self._cached_children = []
            _handle_module("", self)
        for name, child in self._cached_children:
            fn(name, child)

    def _start_streaming(self, batch_size: int, exit_stack: ExitStack):
        """Create and register a fresh streaming state on every streaming submodule."""

        def _start_streaming_fn(name: str, module: StreamingModule):
            if module._streaming_state is not None:
                raise RuntimeError(f"{name} is already streaming!")
            state = module._init_streaming_state(batch_size)
            # Register the state on the stack so its __exit__ runs on context exit.
            exit_stack.enter_context(state)
            module._streaming_state = state

        self._apply_named_streaming(_start_streaming_fn)

    def _stop_streaming(self) -> None:
        """Drop the streaming state from every streaming submodule."""

        def _stop_streaming_fn(name: str, module: StreamingModule):
            module._streaming_state = None

        self._apply_named_streaming(_stop_streaming_fn)

    def _init_streaming_state(self, batch_size: int) -> StreamingState:
        # Default state factory; subclasses may override to return richer state.
        # NOTE(review): assumes the module has at least one parameter —
        # `next(iter(...))` raises StopIteration on a parameter-less module.
        device = next(iter(self.parameters())).device
        return StreamingState(batch_size, device)

    def streaming(self, batch_size: int) -> ExitStack:
        """Context manager to enter streaming mode.

        Starts streaming on all submodules immediately and registers a callback
        that stops streaming when the returned stack is exited.
        """
        exit_stack = ExitStack()
        self._start_streaming(batch_size, exit_stack)
        exit_stack.callback(self._stop_streaming)
        return exit_stack
180
+
181
+
182
class StreamingContainer(StreamingModule):
    """Container for streaming modules.

    Adds no behavior of its own; marks modules that merely aggregate
    streaming children.
    """

    pass
186
+
187
+
188
+ # =============================================================================
189
+ # Normalization Layers
190
+ # =============================================================================
191
+
192
+
193
+ class MossAudioTokenizerRMSNorm(nn.Module):
194
+ """Root Mean Square Layer Normalization."""
195
+
196
+ def __init__(
197
+ self,
198
+ dim: int,
199
+ eps: float = 1e-5,
200
+ dtype: torch.dtype | None = None,
201
+ device=None,
202
+ ):
203
+ super().__init__()
204
+ self.eps = eps
205
+ self.dtype = dtype
206
+ self.alpha = nn.Parameter(torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype))
207
+
208
+ def forward(self, x: torch.Tensor):
209
+ x_dtype = x.dtype
210
+ if self.dtype is not None:
211
+ x = x.to(self.dtype)
212
+ var = self.eps + torch.mean(x**2, dim=2, keepdim=True)
213
+ y = (x * (self.alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
214
+ return y
215
+
216
+
217
class MossAudioTokenizerLayerScale(nn.Module):
    """Layer scale from Touvron et al. 2021 (learned per-channel multiplier).

    Args:
        channels: Number of channels to scale.
        init: Initial value for every channel's scale.
        channel_last: If True expects ``(..., C)`` inputs, else ``(B, C, T)``.
    """

    def __init__(
        self,
        channels: int,
        init: float = 1e-4,
        channel_last: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        """Scale ``x`` channel-wise."""
        if not self.channel_last:
            # (B, C, T): add a trailing axis so the scale broadcasts over time.
            return self.scale[:, None] * x
        return self.scale * x
237
+
238
+
239
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create a normalization module by name.

    Args:
        norm_type: One of ``"layer_norm"``, ``"rms_norm"``, ``"rms_norm_f32"``.
        dim: Normalized dimension.
        **kwargs: Forwarded to the module constructor (``dtype`` is forced to
            float32 for ``"rms_norm_f32"``).

    Raises:
        ValueError: on an unknown ``norm_type``.
    """
    if norm_type == "layer_norm":
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    if norm_type == "rms_norm":
        return MossAudioTokenizerRMSNorm(dim, eps=1e-5, **kwargs)
    if norm_type == "rms_norm_f32":
        # Always computed in float32; a tighter eps is safe at that precision.
        kwargs.pop("dtype", None)
        return MossAudioTokenizerRMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
    raise ValueError(f"Unknown norm type: {norm_type}")
250
+
251
+
252
+ # =============================================================================
253
+ # Rotary Position Embedding
254
+ # =============================================================================
255
+
256
+
257
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """Apply rotary position embedding (RoPE) to query/key tensors.

    Args:
        q: Queries of shape ``(B, H, T, D)``, or ``(B, T, H, D)`` when
            ``time_before_heads`` is True.
        k: Keys, same shape as ``q``.
        offset: Per-batch starting position added to the time index.
        max_period: Maximum period of the sinusoidal frequency ladder.
        time_before_heads: Layout flag for the two inner axes.

    Returns:
        Tuple ``(q_rot, k_rot)`` with the same shape and dtype as the inputs.

    Raises:
        ValueError: on mismatched q/k shapes or an odd last dimension.
    """
    if time_before_heads:
        B, T, H, D = q.shape
    else:
        B, H, T, D = q.shape
    if k.shape != q.shape:
        raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}")
    if D <= 0 or (D % 2) != 0:
        raise ValueError(f"RoPE requires an even last dimension, got D={D}")

    half = D // 2
    # Geometric frequency ladder: exp(-log(max_period) * 2i / D), i in [0, D/2).
    pair_idx = torch.arange(half, device=q.device, dtype=torch.float32)
    freqs = torch.exp(pair_idx * (-math.log(max_period) * 2 / D))
    positions = offset.float().view(-1, 1) + torch.arange(T, device=q.device, dtype=torch.float32)
    positions = positions.view(B, -1, 1, 1) if time_before_heads else positions.view(B, 1, -1, 1)

    lead = q.shape[:-1]
    q = q.view(*lead, half, 2)
    k = k.view(*lead, half, 2)

    # Treat consecutive feature pairs as complex numbers and rotate them.
    q_re, q_im = q[..., 0].float(), q[..., 1].float()
    k_re, k_im = k[..., 0].float(), k[..., 1].float()

    cos = torch.cos(freqs * positions)
    sin = torch.sin(freqs * positions)

    out_dtype = q.dtype
    q_out = torch.stack(
        [(q_re * cos - q_im * sin).to(out_dtype), (q_re * sin + q_im * cos).to(out_dtype)], dim=-1
    )
    k_out = torch.stack(
        [(k_re * cos - k_im * sin).to(out_dtype), (k_re * sin + k_im * cos).to(out_dtype)], dim=-1
    )

    return q_out.view(*lead, D), k_out.view(*lead, D)
303
+
304
+
305
class MossAudioTokenizerRotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) with a configurable maximum period."""

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Rotate q/k according to their absolute positions starting at ``offset``."""
        rotated_q, rotated_k = apply_rope(
            q, k, offset, max_period=self.max_period, time_before_heads=time_before_heads
        )
        return rotated_q, rotated_k
320
+
321
+
322
+ # =============================================================================
323
+ # Gating Modules
324
+ # =============================================================================
325
+
326
+
327
class MossAudioTokenizerActivationGating(nn.Module):
    """Gated feed-forward layer: ``out = W2(act(a) * b)`` with ``[a, b] = W1(x)``.

    Args:
        dim: Model dimension of the input/output.
        dim_feedforward: Nominal FFN width; the gated hidden size is derived
            from it so the parameter count roughly matches an ungated FFN.
        activation: Callable applied to the gate half of the projection.
    """

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        # Special-case the standard 4x FFN to match its parameter budget.
        hidden = (21 * dim) // 8 if dim_feedforward == 4 * dim else (2 * dim_feedforward) // 3

        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        """Apply the gated FFN to ``x`` of shape ``(B, T, dim)``."""
        projected = self.linear_in(x)
        batch, steps, _ = projected.shape
        halves = projected.view(batch, steps, 2, -1)
        gated = self.activation(halves[..., 0, :]) * halves[..., 1, :]
        return self.linear_out(gated)
348
+
349
+
350
+ def _get_activation(name: str):
351
+ if name in ["sigmoid", "tanh", "relu"]:
352
+ return getattr(torch, name)
353
+ elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
354
+ return getattr(F, name)
355
+ elif name == "identity":
356
+ return nn.Identity()
357
+ else:
358
+ raise ValueError(f"Unknown activation {name}")
359
+
360
+
361
def make_gating(name: str, dim: int, dim_feedforward: int, **factory_kwargs) -> nn.Module:
    """Build a gated FFN using the activation named by ``name``."""
    activation = _get_activation(name)
    return MossAudioTokenizerActivationGating(dim, dim_feedforward, activation, **factory_kwargs)
363
+
364
+
365
+ # =============================================================================
366
+ # Positional Embeddings
367
+ # =============================================================================
368
+
369
+
370
def create_sin_embedding(
    positions: torch.Tensor,
    dim: int,
    max_period: float = 10000,
    dtype: torch.dtype = torch.float32,
):
    """Create sinusoidal positional embedding with shape [B, T, C].

    The first ``dim // 2`` channels are cosines and the rest sines, with
    geometrically spaced periods up to ``max_period``.

    Raises:
        ValueError: if ``dim`` is odd or smaller than 4.
    """
    if dim % 2 != 0:
        raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
    half_dim = dim // 2
    if half_dim <= 1:
        raise ValueError(f"Sinusoidal embedding requires dim >= 4, got dim={dim}")
    positions = positions.to(dtype)
    channel_idx = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    base = torch.full([], max_period, device=positions.device, dtype=dtype)
    phase = positions / (base ** (channel_idx / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
387
+
388
+
389
+ # =============================================================================
390
+ # KV Cache for Attention
391
+ # =============================================================================
392
+
393
+
394
class KVCacheResult:
    """Container for KV cache results that supports tuple unpacking.

    Attributes:
        keys: Cached keys ``(B, H, T, D)``.
        values: Cached values ``(B, H, T, D)``.
        positions: Absolute position of every cache slot ``(B, T)``.
    """

    __slots__ = ("keys", "values", "positions")

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor):
        self.keys = keys
        self.values = values
        self.positions = positions

    def __iter__(self):
        """Allow unpacking as (keys, values, positions)."""
        yield self.keys
        yield self.values
        yield self.positions

    @staticmethod
    def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
        """Wrap raw key/value tensors, assigning positions 0..T-1 to every batch row."""
        batch, _, length, _ = keys.shape
        positions = torch.arange(length, device=keys.device, dtype=torch.long).expand(batch, -1)
        return KVCacheResult(keys, values, positions)
413
+
414
+
415
class RingKVCache:
    """Efficient streaming KVCache compatible with CUDA Graph.

    Keys/values are written into a fixed-capacity ring buffer so streaming
    attention never reallocates. ``end_offset`` tracks how many positions have
    been written in total — per batch row when ``respect_exec_mask`` is True,
    otherwise one shared counter for the whole batch.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        dim_per_head: int,
        capacity: int,
        respect_exec_mask: bool = True,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.capacity = capacity
        # cache[0] holds keys, cache[1] holds values.
        self.cache = torch.zeros(
            (2, batch_size, num_heads, capacity, dim_per_head),
            device=device,
            dtype=dtype,
        )
        self.respect_exec_mask = respect_exec_mask
        if self.respect_exec_mask:
            # Per-sample write offsets: samples may advance independently.
            self.end_offset = torch.zeros(batch_size, device=device, dtype=torch.long)
        else:
            # Single shared write offset for the whole batch.
            self.end_offset = torch.zeros(1, device=device, dtype=torch.long)

    def reset(self, reset_mask: torch.Tensor) -> None:
        # Zero the offsets of the selected rows only; stale cache entries are
        # harmless because slots at or beyond end_offset are reported as invalid.
        self.end_offset[:] = torch.where(reset_mask, torch.zeros_like(self.end_offset), self.end_offset)

    def complete(self, k: torch.Tensor, v: torch.Tensor, exec_mask: torch.Tensor) -> KVCacheResult:
        """Insert ``k``/``v`` of shape (B, H, T, D) at the current offsets.

        Returns the full ring buffer plus, per slot, the absolute position it
        holds (-1 for never-written slots) so the caller can build masks.
        """
        B, H, T, D = k.shape
        if T <= 0:
            raise ValueError(f"Expected T > 0, got T={T}")

        # Ring slot for each of the T incoming timesteps, per batch row.
        indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype)
        indexes = indexes + self.end_offset.view(-1, 1)
        indexes = indexes % self.capacity

        if self.respect_exec_mask:
            # Offsets can differ per row, so scatter along the time axis.
            this_indexes = indexes.view(B, 1, T, 1).expand(-1, H, T, D)
            self.cache[0].scatter_(2, this_indexes, k)
            self.cache[1].scatter_(2, this_indexes, v)
        else:
            # Shared offset: one index_copy_ covers every batch row.
            self.cache[0].index_copy_(2, indexes[0], k)
            self.cache[1].index_copy_(2, indexes[0], v)

        keys = self.cache[0]
        values = self.cache[1]

        # Recover the absolute position stored in each ring slot: slots at or
        # before the newest write map directly; later slots belong to the
        # previous wrap of the ring (one full capacity earlier).
        indexes = torch.arange(self.capacity, device=self.end_offset.device, dtype=torch.long)
        last_offset = self.end_offset.view(-1, 1) + T - 1
        end_index = last_offset % self.capacity
        delta = indexes - end_index

        positions = torch.where(
            delta <= 0,
            last_offset + delta,
            last_offset + delta - self.capacity,
        )

        # Advance the write offsets (only for executing rows when masked).
        if self.respect_exec_mask:
            self.end_offset[:] = torch.where(exec_mask, self.end_offset + T, self.end_offset)
        else:
            self.end_offset.add_(T)

        # Mark slots that have never been written as invalid (-1).
        invalid = indexes >= self.end_offset.view(-1, 1)
        positions = torch.where(invalid, torch.full_like(positions, -1), positions)

        return KVCacheResult(keys, values, positions)
483
+
484
+
485
+ # =============================================================================
486
+ # Multi-Head Attention
487
+ # =============================================================================
488
+
489
+
490
@dataclass
class MHAState(StreamingState):
    """Streaming state for multi-head attention."""

    # Ring buffer of past keys/values; None disables caching.
    kv_cache: RingKVCache | None
    # Per-sample absolute position offset (device tensor).
    offset: torch.Tensor
    # Host-side step counter used to pick per-step projection weights.
    offset_cpu: int

    def reset(self, reset_mask: torch.Tensor):
        super().reset(reset_mask)
        self.offset[:] = torch.where(reset_mask, torch.zeros_like(self.offset), self.offset)
        if self.kv_cache is not None:
            self.kv_cache.reset(reset_mask)
        # NOTE(review): resets the shared CPU counter regardless of which rows
        # are masked — presumably only used with full-batch resets; confirm.
        self.offset_cpu = 0
502
+
503
+
504
+ def apply_weights_per_step(
505
+ modules: nn.ModuleList,
506
+ schedule: list[int] | None,
507
+ x: torch.Tensor,
508
+ offset: int | None,
509
+ ):
510
+ """Apply different weights for each time step."""
511
+ if len(modules) == 1:
512
+ return modules[0](x)
513
+
514
+ if offset is None:
515
+ raise ValueError("offset must be provided when using per-step weights (len(modules) > 1).")
516
+ ys = []
517
+ B, T, C = x.shape
518
+ for t in range(T):
519
+ module_index = t + offset
520
+ if schedule is not None:
521
+ if module_index >= len(schedule) or module_index < 0:
522
+ raise ValueError(
523
+ f"weights_per_step_schedule is too short for module_index={module_index} (len={len(schedule)})."
524
+ )
525
+ module_index = schedule[module_index]
526
+ if module_index >= len(modules) or module_index < 0:
527
+ raise ValueError(f"module_index={module_index} out of range for len(modules)={len(modules)}.")
528
+ y = modules[module_index](x[:, t : t + 1])
529
+ ys.append(y)
530
+ return torch.cat(ys, 1)
531
+
532
+
533
class MossAudioTokenizerMultiheadAttention(StreamingModule):
    """Multi-head attention with streaming support.

    Uses fused in-projections (q, k, v in one matmul), optional RoPE, an
    optional causal mask with a bounded context window, and — while
    streaming — a ring KV cache. ``weights_per_step`` allocates distinct
    projection weights per time step, optionally shared via a schedule.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.embed_dim = embed_dim
        self.causal = causal
        self.context = context
        self.rope = rope
        self.num_heads = num_heads
        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule

        # Fused q/k/v projection output size.
        out_dim = 3 * embed_dim
        mult = 1
        if weights_per_step:
            # Number of distinct weight sets (a schedule may reuse sets).
            mult = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step
        self.mult = mult

        self.out_projs = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )
        self.in_projs = nn.ModuleList(
            [nn.Linear(embed_dim, out_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )

        # Remap legacy fused checkpoint tensors into the ModuleList layout.
        self._register_load_state_dict_pre_hook(self._load_hook, with_module=True)

    @staticmethod
    def _load_hook(module, state_dict, prefix, *_):
        """Split legacy fused projection weights into per-step ModuleList entries."""
        mappings = {
            "in_proj_weight": "in_projs.{i}.weight",
            "in_proj.weight": "in_projs.{i}.weight",
            "out_proj.weight": "out_projs.{i}.weight",
        }
        mult = module.mult
        # "_scb" covers checkpoints carrying an extra side tensor with this
        # suffix — presumably quantization scales; verify against exporters.
        for suffix in ["", "_scb"]:
            for source, target in mappings.items():
                this_source = prefix + source + suffix
                if this_source in state_dict:
                    weight = state_dict[this_source]
                    _, *OD = weight.shape
                    # Fused tensor stacks the per-step weights along dim 0.
                    weight = weight.view(mult, -1, *OD)
                    for i in range(mult):
                        state_dict[prefix + target.format(i=i) + suffix] = weight[i]
                    state_dict.pop(this_source)

    def _init_streaming_state(self, batch_size: int) -> MHAState:
        in_proj = cast(nn.Linear, self.in_projs[0])
        device = cast(torch.device, in_proj.weight.device)
        dtype = cast(torch.dtype, in_proj.weight.dtype)

        dim_per_head = self.embed_dim // self.num_heads
        # Cache capacity: the bounded context when set, else the per-step
        # weight horizon, else a 1024-slot fallback.
        if self.context is None:
            capacity = self.weights_per_step if self.weights_per_step else 1024
        else:
            capacity = self.context

        kv_cache = RingKVCache(
            batch_size,
            self.num_heads,
            dim_per_head,
            capacity,
            respect_exec_mask=not self.weights_per_step,
            device=cast(torch.device, device),
            dtype=cast(torch.dtype, dtype),
        )
        return MHAState(
            batch_size,
            cast(torch.device, device),
            kv_cache,
            offset=torch.zeros(batch_size, device=cast(torch.device, device), dtype=torch.long),
            offset_cpu=0,
        )

    def _complete_kv(self, k, v) -> KVCacheResult:
        """Append k/v to the streaming cache, or pass them through when not streaming."""
        state = cast(MHAState | None, self._streaming_state)
        if state is None:
            return KVCacheResult.from_kv(k, v)
        if state.kv_cache is None:
            return KVCacheResult.from_kv(k, v)
        return state.kv_cache.complete(k, v, state.exec_mask)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """Self-attention over ``query`` (B, T, C).

        Note: ``key`` and ``value`` arguments are ignored — q, k and v are all
        derived from ``query`` via the fused in-projection.
        """
        state = cast(MHAState | None, self._streaming_state)
        B, T = query.shape[:2]

        if state is None:
            offset = torch.zeros(B, device=query.device, dtype=torch.long)
            offset_cpu = 0
        else:
            offset = state.offset
            offset_cpu = state.offset_cpu

        # Fused q/k/v projection, with optional per-step weights.
        projected = apply_weights_per_step(self.in_projs, self.weights_per_step_schedule, query, offset_cpu)
        dim_per_head = self.embed_dim // self.num_heads
        # (B, T, 3, H, Dh) -> (3, B, H, T, Dh)
        projected = projected.reshape(B, T, 3, self.num_heads, dim_per_head).permute(2, 0, 3, 1, 4)
        q, k, v = projected[0], projected[1], projected[2]

        if self.rope:
            q, k = self.rope(q, k, offset, time_before_heads=False)

        # Merge with the streaming KV cache (no-op outside streaming).
        k, v, pos_k = self._complete_kv(k, v)
        pos_k = pos_k[:, None]

        if self.causal:
            pos_q = offset.view(-1, 1, 1) + torch.arange(T, device=q.device, dtype=torch.long).view(-1, 1)
            delta = pos_q - pos_k
            # Attend only to valid cache slots (pos >= 0) at or before the query...
            attn_bias = (pos_k >= 0) & (delta >= 0)
            if self.context is not None:
                # ...and within the bounded attention window.
                attn_bias = attn_bias & (delta < self.context)
            attn_bias = attn_bias[:, None]
        else:
            attn_bias = None

        x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
        x = x.transpose(1, 2).reshape(B, T, self.embed_dim)
        x = apply_weights_per_step(self.out_projs, self.weights_per_step_schedule, x, offset_cpu)

        if state is not None:
            # Advance device/host offsets for the next streamed chunk.
            state.offset[:] = torch.where(state.exec_mask, state.offset + T, state.offset)
            state.offset_cpu += T
        return x
669
+
670
+
671
+ # =============================================================================
672
+ # Transformer Layer
673
+ # =============================================================================
674
+
675
+
676
@dataclass
class LayerState(StreamingState):
    """Streaming state for a transformer layer."""

    # Host-side count of timesteps already processed (per-step FFN selection).
    offset_cpu: int = 0

    def reset(self, reset_mask: torch.Tensor):
        super().reset(reset_mask)
        # NOTE(review): resets the shared CPU counter regardless of the mask —
        # presumably only used with full-batch resets; confirm.
        self.offset_cpu = 0
683
+
684
+
685
class MossAudioTokenizerTransformerLayer(StreamingModule):
    """Transformer layer with streaming support.

    Pre-norm self-attention followed by a feed-forward block (plain two-linear
    FFN or gated FFN), each wrapped with an optional LayerScale and a residual
    connection. With ``weights_per_step`` the FFN uses distinct weights per
    time step.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        norm: str = "layer_norm",
        layer_scale: float | None = None,
        gating: str = "none",
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        activation=F.gelu,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.self_attn = MossAudioTokenizerMultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            causal=causal,
            context=context,
            rope=rope,
            weights_per_step=weights_per_step,
            weights_per_step_schedule=weights_per_step_schedule,
            **factory_kwargs,
        )
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)

        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule
        self.gating: nn.Module | nn.ModuleList | None = None
        self.linear1: nn.Module | None = None
        self.linear2: nn.Module | None = None
        self.activation = activation

        # Number of distinct per-step weight sets (a schedule may reuse sets).
        num_weights = 1
        if weights_per_step:
            num_weights = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step

        if gating == "none":
            # Plain two-layer FFN.
            self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs)
            self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs)
        else:
            if weights_per_step:
                # One gated FFN per distinct weight set.
                dim_ff_list = [dim_feedforward] * num_weights if isinstance(dim_feedforward, int) else dim_feedforward
                self.gating = nn.ModuleList(
                    [make_gating(gating, d_model, dim, **factory_kwargs) for dim in dim_ff_list]
                )
            else:
                self.gating = make_gating(gating, d_model, dim_feedforward, **factory_kwargs)

        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs)
            )
            self.layer_scale_2 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, **cast(dict[str, object], factory_kwargs)
            )

    def _init_streaming_state(self, batch_size: int) -> LayerState:
        device = next(iter(self.parameters())).device
        return LayerState(batch_size, device, offset_cpu=0)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block: pre-norm, FFN, LayerScale, residual."""
        state = self._streaming_state
        # The CPU offset selects per-step FFN weights while streaming.
        offset = state.offset_cpu if isinstance(state, LayerState) else 0

        x_orig = x
        x = self.norm2(x)

        if self.gating is None:
            assert self.linear1 is not None
            assert self.linear2 is not None
            update = self.linear2(self.activation(self.linear1(x)))
        else:
            if self.weights_per_step:
                assert isinstance(self.gating, nn.ModuleList)
                update = apply_weights_per_step(self.gating, self.weights_per_step_schedule, x, offset)
            else:
                update = self.gating(x)
        # `.to(update)` aligns the residual's dtype with the FFN output.
        return x_orig.to(update) + self.layer_scale_2(update)

    def _sa_block(self, x: torch.Tensor):
        """Self-attention sub-block: pre-norm, attention, LayerScale, residual."""
        x_orig = x
        x = self.norm1(x)
        update = self.self_attn(x, x, x)
        return x_orig.to(update) + self.layer_scale_1(update)

    def forward(self, x: torch.Tensor):
        x = self._sa_block(x)
        x = self._ff_block(x)
        state = self._streaming_state
        if state is not None:
            assert isinstance(state, LayerState)
            # Advance the host-side step counter for per-step weight selection.
            state.offset_cpu += x.shape[1]
        return x
792
+
793
+
794
+ # =============================================================================
795
+ # Streaming Transformer
796
+ # =============================================================================
797
+
798
+
799
@dataclass
class TransformerState(StreamingState):
    """Streaming state for the transformer stack."""

    # Per-sample absolute position of the next chunk (device tensor).
    offsets: torch.Tensor

    def reset(self, reset_mask: torch.Tensor):
        super().reset(reset_mask)
        self.offsets[:] = torch.where(reset_mask, torch.zeros_like(self.offsets), self.offsets)
806
+
807
+
808
class MossAudioTokenizerTransformer(StreamingModule):
    """Transformer with streaming/causal support.

    Stacks ``num_layers`` MossAudioTokenizerTransformerLayer blocks and adds
    sinusoidal and/or rotary positional information.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        positional_embedding: str = "sin",
        max_period: float = 10_000,
        positional_scale: float = 1.0,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model must be divisible by num_heads, got d_model={d_model}, num_heads={num_heads}")

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale

        # RoPE instance shared by every layer ("rope"/"sin_rope" modes only).
        self.rope: MossAudioTokenizerRotaryEmbedding | None = None
        if positional_embedding in {"rope", "sin_rope"}:
            self.rope = MossAudioTokenizerRotaryEmbedding(max_period=max_period)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                MossAudioTokenizerTransformerLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    dim_feedforward=dim_feedforward,
                    causal=causal,
                    context=context,
                    rope=self.rope,
                    device=device,
                    dtype=dtype,
                    **kwargs,
                )
            )

    def _init_streaming_state(self, batch_size: int) -> TransformerState:
        device = next(self.parameters()).device
        return TransformerState(
            batch_size,
            device,
            offsets=torch.zeros(batch_size, device=device, dtype=torch.long),
        )

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the layer stack on ``x`` of shape (B, T, C)."""
        B, T, C = x.shape
        state = self._streaming_state
        # While streaming, offsets carry the absolute position of this chunk.
        offsets = (
            torch.zeros(1, dtype=torch.long, device=x.device)
            if state is None
            else (
                state.offsets
                if isinstance(state, TransformerState)
                else torch.zeros(1, dtype=torch.long, device=x.device)
            )
        )

        if self.positional_embedding in {"sin", "sin_rope"}:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = layer(x, *args, **kwargs)

        if state is not None:
            assert isinstance(state, TransformerState)
            # Advance only the rows that actually executed this chunk.
            state.offsets[:] = torch.where(state.exec_mask, state.offsets + T, state.offsets)
        return x
888
+
889
+
890
class MossAudioTokenizerProjectedTransformer(StreamingContainer):
    """Transformer with input/output projections.

    Operates on conv-style (B, D, T) tensors: projects D -> d_model, runs the
    transformer in (B, T, D) layout, then projects back to (B, D, T).
    Projections collapse to nn.Identity when dimensions already match.
    """

    def __init__(
        self,
        input_dimension: int,
        output_dimension: int,
        d_model: int,
        *,
        conv_layout: bool = False,
        module_type: str,
        **kwargs,
    ):
        super().__init__()
        self.module_type = module_type
        # This module does not change the frame rate.
        self.downsample_ratio: int = 1
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension

        self.input_proj = (
            nn.Linear(input_dimension, d_model, bias=False) if d_model != input_dimension else nn.Identity()
        )
        self.transformer = MossAudioTokenizerTransformer(d_model=d_model, **kwargs)
        # NOTE(review): conv_layout is stored but not consulted in this forward;
        # the code below always assumes conv (B, D, T) inputs — confirm callers.
        self.conv_layout = conv_layout
        self.output_proj = (
            nn.Linear(d_model, output_dimension, bias=False) if d_model != output_dimension else nn.Identity()
        )

    def forward(self, x, input_lengths, *args, **kwargs):
        """Project, run the transformer, project back; lengths pass through unchanged."""
        x = self.input_proj(x.transpose(1, 2))  # (B, D, T) -> (B, T, D)
        x = self.transformer(x, *args, **kwargs)
        x = self.output_proj(x).transpose(1, 2)  # (B, T, D) -> (B, D, T)
        return x, input_lengths
923
+
924
+
925
+ # =============================================================================
926
+ # Patched Pretransform Module
927
+ # =============================================================================
928
+
929
+
930
class MossAudioTokenizerPatchedPretransform(nn.Module):
    """Patching module for downsampling/upsampling.

    ``encode`` folds ``patch_size`` consecutive frames into the channel axis
    ((B, D, T*h) -> (B, D*h, T)); ``decode`` is the exact inverse.

    Args:
        patch_size: Number of frames folded per patch.
        is_downsample: Selects which direction ``forward`` dispatches to.
        module_type: Free-form tag kept for introspection.
    """

    def __init__(self, patch_size: int, is_downsample: bool, module_type: str, **kwargs):
        super().__init__()
        self.patch_size = patch_size
        self.downsample_ratio: int = patch_size
        self.is_downsample = is_downsample
        self.module_type = module_type

    def encode(self, x, input_lengths):
        """Fold time into channels; returns the patched tensor and new lengths."""
        batch, channels, _ = x.shape
        patch = self.patch_size
        folded = x.reshape(batch, channels, -1, patch).permute(0, 1, 3, 2)
        folded = folded.reshape(batch, channels * patch, -1)
        # We pad the input waveform to a multiple of `downsample_rate` before applying the encoder.
        # Use a ceil division to match that padding and avoid dropping the last (partially padded) frame.
        output_lengths = (input_lengths + patch - 1) // patch
        return folded, output_lengths

    def decode(self, x, input_lengths):
        """Unfold channels back into time; inverse of :meth:`encode`."""
        batch, packed_channels, steps = x.shape
        patch = self.patch_size
        channels = packed_channels // patch
        unfolded = x.reshape(batch, channels, patch, steps).permute(0, 1, 3, 2)
        unfolded = unfolded.reshape(batch, channels, steps * patch)
        return unfolded, input_lengths * patch

    def forward(self, x, input_lengths):
        """Dispatch to encode or decode based on ``is_downsample``."""
        if self.is_downsample:
            return self.encode(x, input_lengths)
        return self.decode(x, input_lengths)
962
+
963
+
964
+ # =============================================================================
965
+ # Vector Quantization
966
+ # =============================================================================
967
+
968
+
969
+ def WNConv1d(*args, **kwargs):
970
+ """Weight-normalized Conv1d."""
971
+ return nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs))
972
+
973
+
974
class MossAudioTokenizerVectorQuantize(nn.Module):
    """Single codebook vector quantization (inference only).

    Projects inputs into the codebook space when dims differ, snaps every
    frame to its nearest (L2) codebook entry, and projects back.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        needs_proj = input_dim != codebook_dim
        # 1x1 weight-normalized convs adapt dims only when necessary.
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) if needs_proj else nn.Identity()
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) if needs_proj else nn.Identity()

        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    @torch.no_grad()
    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: Input tensor of shape (B, D, T)
        Returns:
            z_q: Quantized tensor of shape (B, D, T)
            indices: Code indices of shape (B, T)
            z_e: Encoded tensor before quantization
        """
        z = z.float()
        z_e = self.in_proj(z).float()

        # Flatten to (B*T, codebook_dim) rows for the distance computation.
        flat = z_e.transpose(1, 2).reshape(-1, z_e.shape[1])

        weight = self.codebook.weight.float()
        # Squared L2 distance expanded as |x|^2 - 2 x.w + |w|^2.
        dist = flat.pow(2).sum(1, keepdim=True) - 2 * flat @ weight.t() + weight.pow(2).sum(1, keepdim=True).t()

        # Nearest code per row, reshaped back to (B, T).
        indices = (-dist).max(1)[1].reshape(z.size(0), -1)

        z_q = self.out_proj(self.decode_code(indices)).float()
        return z_q, indices, z_e

    def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Decode code indices (B, T) to embeddings (B, codebook_dim, T)."""
        return self.codebook(embed_id).transpose(1, 2).float()
1031
+
1032
+
1033
class MossAudioTokenizerLFQ(nn.Module):
    """LFQ (inference-only) used by ResidualLFQ.

    Like plain VQ, but the nearest-code search runs on L2-normalized latents
    and codes (i.e. cosine similarity).
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        needs_proj = self.input_dim != self.codebook_dim
        # 1x1 weight-normalized convs adapt dims only when necessary.
        self.in_proj = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) if needs_proj else nn.Identity()
        self.out_proj = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) if needs_proj else nn.Identity()

        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    @torch.no_grad()
    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize z into codebook vectors."""
        z = z.float()
        z_e = self.in_proj(z).float()
        quantized, indices = self.decode_latents(z_e)
        # Straight-through estimator form (numerically a no-op at inference).
        quantized = (z_e + (quantized - z_e).detach()).float()
        return self.out_proj(quantized).float(), indices, z_e

    def embed_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Look up raw codebook vectors for the given indices."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code_wo_out_proj(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Indices (B, T) -> codebook vectors (B, codebook_dim, T), no output projection."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Indices -> embeddings in the input space (with output projection)."""
        return self.out_proj(self.decode_code_wo_out_proj(embed_id).float()).float()

    def decode_latents(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Match training LFQ: L2-normalize then argmin squared distance."""
        flat = latents.transpose(1, 2).reshape(-1, latents.shape[1]).float()
        codes = self.codebook.weight.float()

        flat = F.normalize(flat)
        codes = F.normalize(codes)

        # Squared L2 distance expanded as |x|^2 - 2 x.w + |w|^2.
        dist = flat.pow(2).sum(1, keepdim=True) - 2 * flat @ codes.t() + codes.pow(2).sum(1, keepdim=True).t()

        indices = (-dist).max(1)[1].reshape(latents.size(0), -1)
        return self.decode_code_wo_out_proj(indices).float(), indices
1096
+
1097
+
1098
class MossAudioTokenizerResidualVQ(nn.Module):
    """Residual Vector Quantization (inference only).

    A cascade of `num_quantizers` vector quantizers; each stage quantizes the
    residual left over by the previous one. Optional 1x1 convolutions project
    between `input_dim`, the internal `rvq_dim`, and `output_dim`.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        rvq_dim: int | None = None,
        output_dim: int | None = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.rvq_dim = rvq_dim or input_dim
        self.output_dim = output_dim or input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Project in/out of the RVQ working dimension only when needed.
        self.input_proj = (
            WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity()
        )
        self.output_proj = (
            WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
            if self.rvq_dim != self.output_dim
            else nn.Identity()
        )

        self.quantizers = nn.ModuleList(
            [
                MossAudioTokenizerVectorQuantize(
                    input_dim=self.rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    @torch.no_grad()
    def forward(
        self,
        z: torch.Tensor,
        input_length: torch.Tensor,
        n_quantizers: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: Input tensor of shape (B, D, T)
            input_length: Valid lengths for each sample (B,)
            n_quantizers: Number of quantizers to use
        Returns:
            quantized_out: Quantized output (B, D, T)
            all_indices: All code indices (N, B, T)
            output_length: Output lengths (B,)
        """
        z = self.input_proj(z)

        batch_size, _, max_time = z.shape
        # True for valid (non-padded) time steps.
        mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)

        quantized_out = torch.zeros_like(z, dtype=torch.float32)
        residual = z.clone().float()
        all_indices = []

        n_quantizers = n_quantizers or self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if i >= n_quantizers:
                break

            # Quantize only the valid region; padded positions contribute zero.
            masked_residual = residual * mask.unsqueeze(1)
            z_q_i, indices_i, _ = quantizer(masked_residual)

            update_mask = mask.unsqueeze(1)
            quantized_out = quantized_out + z_q_i * update_mask
            residual = residual - z_q_i * update_mask
            all_indices.append(indices_i)

        # Guard the degenerate case (e.g. `n_quantizers == 0`) instead of letting
        # `torch.stack` raise on an empty list; mirrors MossAudioTokenizerResidualLFQ.
        all_indices = (
            torch.stack(all_indices)  # (N, B, T)
            if all_indices
            else torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long)
        )
        quantized_out = self.output_proj(quantized_out)

        return quantized_out, all_indices, input_length

    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode codes from multiple quantizers to embeddings.

        Args:
            codes: Code indices of shape (N, B, T); only the first N quantizer
                stages are applied.
        Returns:
            Decoded embeddings of shape (B, output_dim, T).
        """
        nq, B, T = codes.shape
        emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)

        for i, quantizer in enumerate(self.quantizers[:nq]):
            quantizer = cast(MossAudioTokenizerVectorQuantize, quantizer)
            quantized_i = quantizer.decode_code(codes[i])
            emb += quantized_i

        emb = self.output_proj(emb)
        return emb
1197
+
1198
+
1199
class MossAudioTokenizerResidualLFQ(nn.Module):
    """Residual LFQ (inference only).

    A cascade of LFQ stages; each stage quantizes the residual left by the
    previous one, with optional 1x1 input/output projections.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        rvq_dim: int | None = None,
        output_dim: int | None = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.rvq_dim = rvq_dim if rvq_dim else input_dim
        self.output_dim = output_dim if output_dim else input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Only add projections when the dimensions actually differ.
        if input_dim != self.rvq_dim:
            self.input_proj = WNConv1d(input_dim, self.rvq_dim, kernel_size=1)
        else:
            self.input_proj = nn.Identity()
        if self.rvq_dim != self.output_dim:
            self.output_proj = WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
        else:
            self.output_proj = nn.Identity()

        stages = [
            MossAudioTokenizerLFQ(
                input_dim=self.rvq_dim,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                **kwargs,
            )
            for _ in range(num_quantizers)
        ]
        self.quantizers = nn.ModuleList(stages)

    @torch.no_grad()
    def forward(
        self,
        z: torch.Tensor,
        input_length: torch.Tensor,
        n_quantizers: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Inference quantization.

        Args:
            z: Input latents of shape (B, D, T).
            input_length: Valid lengths per sample, shape (B,).
            n_quantizers: How many stages to apply (defaults to all).
        Returns:
            Tuple of (quantized output (B, D, T), code indices (N, B, T),
            output lengths (B,)).
        """
        z = self.input_proj(z).float()

        batch_size, _, max_time = z.shape
        positions = torch.arange(max_time, device=z.device)
        # (B, 1, T) validity mask; padded positions contribute zero.
        valid = (positions.expand(batch_size, max_time) < input_length.unsqueeze(1)).unsqueeze(1)

        quantized_out = torch.zeros_like(z, dtype=torch.float32)
        residual = z.clone().float()
        collected: list[torch.Tensor] = []

        num_stages = n_quantizers or self.num_quantizers
        for stage in self.quantizers[:num_stages]:
            z_q_i, indices_i, _ = stage(residual * valid)
            quantized_out = quantized_out + z_q_i * valid
            residual = residual - z_q_i * valid
            collected.append(indices_i)

        if collected:
            all_indices = torch.stack(collected)
        else:
            # Degenerate case (no stages applied): keep a well-typed empty tensor.
            all_indices = torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long)
        return self.output_proj(quantized_out), all_indices, input_length

    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode stacked code indices (N, B, T) back to embeddings (B, output_dim, T)."""
        num_stages, batch, steps = codes.shape
        emb = torch.zeros(batch, self.rvq_dim, steps, device=codes.device, dtype=torch.float32)
        for stage_idx, stage in enumerate(self.quantizers[:num_stages]):
            stage = cast(MossAudioTokenizerLFQ, stage)
            emb = emb + stage.decode_code(codes[stage_idx]).float()
        return self.output_proj(emb)
1287
+
1288
+
1289
+ # =============================================================================
1290
+ # Main Model Classes
1291
+ # =============================================================================
1292
+
1293
+
1294
@auto_docstring
class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase):
    """Base class wiring MossAudioTokenizer models into the Transformers
    pretrained-model machinery (config binding, weight init, sharding hints)."""

    config_class = MossAudioTokenizerConfig
    base_model_prefix = ""
    main_input_name = "input_values"
    input_modalities = "audio"
    supports_gradient_checkpointing = False
    _no_split_modules = [
        "MossAudioTokenizerTransformerLayer",
        "MossAudioTokenizerResidualVQ",
        "MossAudioTokenizerResidualLFQ",
    ]

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize LayerScale parameters to a small constant; leave every
        other module to its own default initialization."""
        if not isinstance(module, MossAudioTokenizerLayerScale):
            return
        nn.init.constant_(module.scale, 1e-4)
1312
+
1313
+
1314
@auto_docstring(
    custom_intro="""
    The MossAudioTokenizer neural audio codec model for audio tokenization and synthesis.
    """
)
class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
    """
    MossAudioTokenizer model for audio tokenization and synthesis.

    This model can encode audio waveforms into discrete tokens and decode
    tokens back into audio waveforms.
    """

    def __init__(self, config: MossAudioTokenizerConfig):
        super().__init__(config)

        self.config = config
        # NOTE(review): `version` is read and discarded — presumably to fail fast
        # on configs missing the attribute; confirm the intent.
        _ = config.version
        self.sampling_rate = config.sampling_rate
        self.downsample_rate = config.downsample_rate
        self.causal_transformer_context_duration = config.causal_transformer_context_duration

        # Build encoder
        # `current_frame_rate` tracks the frame rate seen by each successive
        # module so the causal-attention context (in frames) can be derived
        # from the context duration (in seconds).
        current_frame_rate: float = float(self.sampling_rate)
        self.encoder = nn.ModuleList()

        for encoder_kwargs_i in config.encoder_kwargs:
            encoder_kwargs_i = dict(encoder_kwargs_i)  # Make a copy
            if encoder_kwargs_i["module_type"] == "PatchedPretransform":
                self.encoder.append(MossAudioTokenizerPatchedPretransform(**encoder_kwargs_i, is_downsample=True))
            elif encoder_kwargs_i["module_type"] == "Transformer":
                self.encoder.append(
                    MossAudioTokenizerProjectedTransformer(
                        **encoder_kwargs_i,
                        context=int(current_frame_rate * self.causal_transformer_context_duration),
                    )
                )
            # NOTE(review): this divides by `downsample_ratio` after *every*
            # module, including Transformer modules — assumes they also expose
            # a `downsample_ratio` (presumably 1); confirm.
            current_frame_rate /= cast(MossAudioTokenizerPatchedPretransform, self.encoder[-1]).downsample_ratio

        # Build quantizer
        quantizer_kwargs = dict(config.quantizer_kwargs)
        # Quantizer type may live in `quantizer_kwargs` or directly on the config.
        quantizer_type = quantizer_kwargs.get("quantizer_type", getattr(config, "quantizer_type", "rvq"))
        if quantizer_type in {"rvq", "spec_rvq"}:
            self.quantizer = MossAudioTokenizerResidualVQ(**quantizer_kwargs)
        elif quantizer_type in {"rlfq", "random_prefix_rlfq"}:
            self.quantizer = MossAudioTokenizerResidualLFQ(**quantizer_kwargs)
        else:
            raise ValueError(f"Unsupported quantizer_type: {quantizer_type}")

        # Build decoder
        decoder_kwargs_list = copy.deepcopy(config.decoder_kwargs)
        self.decoder = nn.ModuleList()

        for decoder_kwargs_i in decoder_kwargs_list:
            decoder_kwargs_i = dict(decoder_kwargs_i)
            if decoder_kwargs_i["module_type"] == "PatchedPretransform":
                self.decoder.append(MossAudioTokenizerPatchedPretransform(**decoder_kwargs_i, is_downsample=False))
            elif decoder_kwargs_i["module_type"] == "Transformer":
                self.decoder.append(
                    MossAudioTokenizerProjectedTransformer(
                        **decoder_kwargs_i,
                        context=int(current_frame_rate * self.causal_transformer_context_duration),
                    )
                )
            # Decoder upsamples, so the frame rate grows back toward the
            # sampling rate (same per-module assumption as the encoder loop).
            current_frame_rate *= cast(MossAudioTokenizerPatchedPretransform, self.decoder[-1]).downsample_ratio

        self.post_init()

    def _start_streaming(self, batch_size: int):
        """Start streaming mode for all modules."""

        def _start(module):
            # Attach a fresh per-module streaming state (e.g. KV cache).
            if isinstance(module, StreamingModule):
                module._streaming_state = module._init_streaming_state(batch_size)

        self.apply(_start)

    def _stop_streaming(self):
        """Stop streaming mode for all modules."""

        def _stop(module):
            # Drop any accumulated streaming state.
            if isinstance(module, StreamingModule):
                module._streaming_state = None

        self.apply(_stop)

    @contextmanager
    def streaming(self, batch_size: int = 1):
        """Context manager for streaming mode: state is initialized on entry
        and always torn down on exit, even if the body raises."""
        self._start_streaming(batch_size)
        try:
            yield
        finally:
            self._stop_streaming()

    @torch.no_grad()
    def batch_encode(self, wav_list: list[torch.Tensor]) -> MossAudioTokenizerEncoderOutput:
        """Batch encode a list of audio waveforms.

        Args:
            wav_list: List of audio tensors, each of shape `(num_samples,)`.

        Returns:
            [`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.
        """
        if len(wav_list) == 0:
            raise ValueError("`wav_list` must contain at least one waveform.")

        device = wav_list[0].device
        batch_size = len(wav_list)

        # Right-pad every waveform to the longest one and track true lengths.
        max_length = max(wav.shape[-1] for wav in wav_list)
        input_values = torch.zeros(batch_size, 1, max_length, device=device)
        input_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)

        for i, wav in enumerate(wav_list):
            input_values[i, 0, : wav.shape[-1]] = wav
            input_lengths[i] = wav.shape[-1]

        return self._encode_frame(input_values, input_lengths)

    @torch.no_grad()
    def batch_decode(self, codes_list: list[torch.Tensor]) -> MossAudioTokenizerDecoderOutput:
        """Batch decode a list of audio codes.

        Args:
            codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.

        Returns:
            [`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.
        """
        if len(codes_list) == 0:
            raise ValueError("`codes_list` must contain at least one code tensor.")

        batch_size = len(codes_list)
        device = codes_list[0].device
        # NOTE(review): assumes every entry uses the same number of quantizers
        # as the first one — confirm upstream callers guarantee this.
        num_quantizers = codes_list[0].shape[0]
        max_length = max(codes.shape[-1] for codes in codes_list)

        # Right-pad code sequences to the longest one; padding decodes are
        # truncated away via `audio_codes_lengths`.
        audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long)
        audio_codes_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)

        for i, codes in enumerate(codes_list):
            audio_codes[:, i, : codes.shape[-1]] = codes
            audio_codes_lengths[i] = codes.shape[-1]

        return self._decode_frame(audio_codes, audio_codes_lengths)

    @torch.no_grad()
    def _encode_frame(
        self,
        input_values: torch.Tensor,
        input_lengths: torch.Tensor | None = None,
        n_quantizers: int | None = None,
    ) -> MossAudioTokenizerEncoderOutput:
        """Tokenize audio waveform into discrete tokens."""
        # Handle input shape
        if input_values.dim() == 2:
            input_values = input_values.unsqueeze(1)

        B, _, T = input_values.shape
        device = input_values.device

        if input_lengths is None:
            input_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        # Pad to multiple of downsample_rate
        # NOTE(review): `input_lengths` is intentionally *not* extended to cover
        # the padding, so padded samples are treated as invalid — confirm.
        if T % self.downsample_rate != 0:
            pad_length = self.downsample_rate - (T % self.downsample_rate)
            input_values = F.pad(input_values, (0, pad_length))

        # Encode
        e, e_lengths = input_values, input_lengths
        for encoder_module in self.encoder:
            e, e_lengths = encoder_module(e, e_lengths)

        # Quantize
        quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
        # `zq` (the quantized latents) is not needed here; only codes are returned.
        zq, audio_codes, audio_codes_lengths = quantizer(e, e_lengths, n_quantizers)

        return MossAudioTokenizerEncoderOutput(
            audio_codes=audio_codes, audio_codes_lengths=audio_codes_lengths, encoder_hidden_states=e
        )

    @torch.no_grad()
    def _decode_frame(
        self,
        codes: torch.Tensor,
        codes_lengths: torch.Tensor | None = None,
    ) -> MossAudioTokenizerDecoderOutput:
        """Detokenize discrete tokens into audio waveform."""
        nq, B, T = codes.shape
        device = codes.device

        if codes_lengths is None:
            codes_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        # Decode from codes
        quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
        zq = quantizer.decode_codes(codes)

        # Run the decoder stack; each module maps (features, lengths) pairs.
        d, d_lengths = zq, codes_lengths
        for decoder_module in self.decoder:
            d, d_lengths = decoder_module(d, d_lengths)

        return MossAudioTokenizerDecoderOutput(audio=d, audio_lengths=d_lengths)

    def encode(  # type: ignore[override]
        self,
        input_values: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        num_quantizers: int | None = None,
        return_dict: bool | None = None,
        chunk_duration: float | None = None,
    ):
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to indicate valid audio samples.
            num_quantizers (`int`, *optional*):
                Number of quantizers to use. By default, all quantizers are used.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            chunk_duration (`float`, *optional*):
                If provided, encode the input waveform in successive chunks of `chunk_duration` seconds while keeping a
                streaming KV cache for the causal transformers.

                `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
                `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.

        Returns:
            `MossAudioTokenizerEncoderOutput` or tuple containing audio codes and lengths.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Handle input shape
        if input_values.dim() == 2:
            input_values = input_values.unsqueeze(1)

        B, _, T = input_values.shape
        device = input_values.device

        if padding_mask is not None:
            input_lengths = padding_mask.sum(dim=-1).long()
        else:
            input_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        if chunk_duration is None:
            # Non-streaming path: one full-sequence pass.
            encoder_output = self._encode_frame(input_values, input_lengths, num_quantizers)
        else:
            # Streaming path: validate chunking parameters up front.
            if chunk_duration <= 0:
                raise ValueError("`chunk_duration` must be > 0 when provided.")
            if chunk_duration > self.causal_transformer_context_duration:
                raise ValueError(
                    "`chunk_duration` must be <= `config.causal_transformer_context_duration` "
                    f"({self.causal_transformer_context_duration}), got {chunk_duration}."
                )
            if B != 1:
                raise ValueError("Streaming encode via `chunk_duration` currently only supports batch_size=1.")

            chunk_length = int(round(chunk_duration * self.sampling_rate))
            if chunk_length <= 0:
                raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
            if chunk_length % self.downsample_rate != 0:
                raise ValueError(
                    "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
                    f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
                )

            input_length = int(input_lengths[0].item())
            if input_length <= chunk_length:
                # Everything fits in a single chunk; no streaming state needed.
                encoder_output = self._encode_frame(input_values[..., :input_length], input_lengths, num_quantizers)
            else:
                codes_chunks: list[torch.Tensor] = []
                hidden_chunks: list[torch.Tensor] = []

                # Keep every streaming-capable encoder module in streaming mode
                # (KV cache alive) for the whole chunk loop.
                with ExitStack() as exit_stack:
                    for encoder_module in self.encoder:
                        if isinstance(encoder_module, StreamingModule):
                            exit_stack.enter_context(encoder_module.streaming(batch_size=B))

                    for start_idx in range(0, input_length, chunk_length):
                        input_length_i = min(chunk_length, input_length - start_idx)
                        if input_length_i <= 0:
                            break

                        input_lengths_i = torch.tensor([input_length_i], device=device, dtype=torch.long)
                        input_values_i = input_values[..., start_idx : start_idx + input_length_i]
                        result_i = self._encode_frame(input_values_i, input_lengths_i, num_quantizers)

                        if result_i.audio_codes is None or result_i.audio_codes_lengths is None:
                            raise RuntimeError("Internal error: `_encode_frame` returned empty audio codes.")
                        if result_i.encoder_hidden_states is None:
                            raise RuntimeError("Internal error: `_encode_frame` returned empty encoder hidden states.")

                        # Keep only the valid (non-padded) frames of each chunk.
                        codes_length_i = result_i.audio_codes_lengths
                        codes_chunks.append(result_i.audio_codes[:, :, : codes_length_i[0]])
                        hidden_chunks.append(result_i.encoder_hidden_states[:, :, : codes_length_i[0]])

                audio_codes = torch.cat(codes_chunks, dim=-1)
                encoder_hidden_states = torch.cat(hidden_chunks, dim=-1)
                audio_codes_lengths = torch.tensor([audio_codes.shape[-1]], device=device, dtype=torch.long)
                encoder_output = MossAudioTokenizerEncoderOutput(
                    audio_codes=audio_codes,
                    audio_codes_lengths=audio_codes_lengths,
                    encoder_hidden_states=encoder_hidden_states,
                )

        if not return_dict:
            assert encoder_output.audio_codes is not None
            assert encoder_output.audio_codes_lengths is not None
            return (
                cast(torch.Tensor, encoder_output.audio_codes),
                cast(torch.Tensor, encoder_output.audio_codes_lengths),
            )
        return encoder_output

    def decode(  # type: ignore[override]
        self,
        audio_codes: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        return_dict: bool | None = None,
        chunk_duration: float | None = None,
    ):
        """
        Decodes the given codes into an output audio waveform.

        Args:
            audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`):
                Discrete code embeddings computed using `model.encode`.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to indicate valid code positions.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            chunk_duration (`float`, *optional*):
                If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
                streaming KV cache for the causal transformers.

                `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
                `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.

        Returns:
            `MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if audio_codes.dim() == 2:
            audio_codes = audio_codes.unsqueeze(1)  # nq, T -> nq, B=1, T

        _, B, T = audio_codes.shape
        device = audio_codes.device

        if padding_mask is not None:
            codes_lengths = padding_mask.sum(dim=-1).long()
        else:
            codes_lengths = torch.full((B,), T, device=device, dtype=torch.long)

        if chunk_duration is None:
            # Non-streaming path: decode the full code sequence in one pass.
            decoder_output = self._decode_frame(audio_codes, codes_lengths)
        else:
            # Streaming path: same parameter validation as `encode`.
            if chunk_duration <= 0:
                raise ValueError("`chunk_duration` must be > 0 when provided.")
            if chunk_duration > self.causal_transformer_context_duration:
                raise ValueError(
                    "`chunk_duration` must be <= `config.causal_transformer_context_duration` "
                    f"({self.causal_transformer_context_duration}), got {chunk_duration}."
                )
            if B != 1:
                raise ValueError("Streaming decode via `chunk_duration` currently only supports batch_size=1.")

            chunk_length = int(round(chunk_duration * self.sampling_rate))
            if chunk_length <= 0:
                raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
            if chunk_length % self.downsample_rate != 0:
                raise ValueError(
                    "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
                    f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
                )

            # Codes are at the downsampled frame rate, so convert the chunk
            # size from samples to frames.
            chunk_frame_length = chunk_length // self.downsample_rate
            codes_length = int(codes_lengths[0].item())
            if codes_length <= chunk_frame_length:
                decoder_output = self._decode_frame(audio_codes[..., :codes_length], codes_lengths)
            else:
                wav_chunks: list[torch.Tensor] = []
                # Keep streaming-capable decoder modules in streaming mode for
                # the whole chunk loop.
                with ExitStack() as exit_stack:
                    for decoder_module in self.decoder:
                        if isinstance(decoder_module, StreamingModule):
                            exit_stack.enter_context(decoder_module.streaming(batch_size=B))

                    for start_idx in range(0, codes_length, chunk_frame_length):
                        codes_length_i = min(chunk_frame_length, codes_length - start_idx)
                        if codes_length_i <= 0:
                            break

                        codes_lengths_i = torch.tensor([codes_length_i], device=device, dtype=torch.long)
                        codes_i = audio_codes[:, :, start_idx : start_idx + codes_length_i]
                        result_i = self._decode_frame(codes_i, codes_lengths_i)

                        if result_i.audio is None or result_i.audio_lengths is None:
                            raise RuntimeError("Internal error: `_decode_frame` returned empty audio.")

                        # Keep only the valid samples of each decoded chunk.
                        wav_chunks.append(result_i.audio[:, :, : result_i.audio_lengths[0]])

                wav = torch.cat(wav_chunks, dim=-1)
                audio_lengths = torch.tensor([wav.shape[-1]], device=device, dtype=torch.long)
                decoder_output = MossAudioTokenizerDecoderOutput(audio=wav, audio_lengths=audio_lengths)

        if not return_dict:
            assert decoder_output.audio is not None
            return (cast(torch.Tensor, decoder_output.audio),)
        return decoder_output

    @auto_docstring
    def forward(
        self,
        input_values: torch.FloatTensor | None = None,
        padding_mask: torch.BoolTensor | None = None,
        audio_codes: torch.Tensor | None = None,
        num_quantizers: int | None = None,
        return_dict: bool | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | MossAudioTokenizerOutput:  # type: ignore[override]
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Raw audio input converted to Float.
        padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid computing on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
            Discrete code embeddings computed using `model.encode`.
        num_quantizers (`int`, *optional*):
            Number of quantizers (codebooks) to use. By default, all quantizers are used.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import MossAudioTokenizerModel

        >>> model = MossAudioTokenizerModel.from_pretrained("moss_audio_tokenizer-model")

        >>> # Create dummy audio input
        >>> audio = torch.randn(1, 1, 24000)  # 1 second of audio at 24kHz

        >>> outputs = model(input_values=audio)
        >>> audio_codes = outputs.audio_codes
        >>> audio_values = outputs.audio
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        output_audio_codes: torch.Tensor | None = None
        output_audio_codes_lengths: torch.Tensor | None = None
        output_audio: torch.Tensor | None = None
        output_audio_lengths: torch.Tensor | None = None
        decoded_from_encoded_codes = False

        # Encode if input_values provided
        if input_values is not None:
            encoder_output = self.encode(input_values, padding_mask, num_quantizers, return_dict=True)
            encoder_output = cast(MossAudioTokenizerEncoderOutput, encoder_output)
            output_audio_codes = encoder_output.audio_codes
            output_audio_codes_lengths = encoder_output.audio_codes_lengths

            # If codes not provided separately, use encoded codes for decoding
            if audio_codes is None:
                audio_codes = output_audio_codes
                decoded_from_encoded_codes = True

        # Decode if codes available
        if audio_codes is not None:
            # If we're decoding the codes we just produced, use the computed lengths so we don't decode padded garbage.
            if decoded_from_encoded_codes and output_audio_codes_lengths is not None:
                decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths)
            else:
                decoder_output = self.decode(audio_codes, padding_mask=padding_mask, return_dict=True)
            decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output)
            output_audio = decoder_output.audio
            output_audio_lengths = decoder_output.audio_lengths

        if not return_dict:
            return (output_audio_codes, output_audio, output_audio_lengths)

        return MossAudioTokenizerOutput(
            audio=output_audio,
            audio_lengths=output_audio_lengths,
            audio_codes=output_audio_codes,
            audio_codes_lengths=output_audio_codes_lengths,
        )
1810
+
1811
+
1812
# Public API of this remote-code module (picked up by `from ... import *`).
__all__ = ["MossAudioTokenizerModel", "MossAudioTokenizerPreTrainedModel"]