---
license: apache-2.0
library_name: transformers
tags:
- audio
- audio-tokenizer
- neural-codec
- moss-tts-family
- MOSS Audio Tokenizer
- speech-tokenizer
- trust-remote-code
---
# MossAudioTokenizer
This is the code for MOSS-Audio-Tokenizer presented in [MOSS-Audio-Tokenizer: Scaling Audio Tokenizers for Future Audio Foundation Models](https://arxiv.org/abs/2602.10934).
**MOSSAudioTokenizer** is a unified discrete audio tokenizer based on the **Cat** (**C**ausal **A**udio **T**okenizer with **T**ransformer) architecture. Scaled to 1.6 billion parameters, it functions as a unified discrete interface, delivering both near-lossless reconstruction quality and high-level semantic alignment.
**Key Features:**
* **Extreme Compression & Variable Bitrate**: It compresses 48kHz stereo audio into a remarkably low frame rate of 12.5Hz. Utilizing a 32-layer Residual LFQ quantizer stack, it supports high-fidelity reconstruction across a wide range of bitrates.
* **Pure Transformer Architecture**: The model features a "CNN-free" homogeneous architecture built entirely from Causal Transformer blocks. With 1.6B combined parameters (Encoder + Decoder), it ensures exceptional scalability and supports low-latency streaming inference.
* **Large-Scale General Audio Training**: Trained on 3 million hours of diverse audio data, the model excels at encoding and reconstructing all audio domains, including speech, sound effects, and music.
* **Unified Semantic-Acoustic Representation**: While achieving state-of-the-art reconstruction quality, Cat produces discrete tokens that are "semantic-rich," making them ideal for downstream tasks like speech understanding (ASR) and generation (TTS).
* **Fully Trained From Scratch**: Cat does not rely on any pretrained encoders (such as HuBERT or Whisper) or distillation from teacher models. All representations are learned autonomously from raw data.
* **End-to-End Joint Optimization**: All components—including the encoder, quantizer, decoder, discriminator, and a decoder-only LLM for semantic alignment—are optimized jointly in a single unified training pipeline.
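As a rough illustration of the variable-bitrate design above: the nominal bitrate scales linearly with how many residual LFQ layers are kept at decode time. The bits-per-layer value below is a hypothetical placeholder for the sketch, not a figure from the paper:

```python
# Back-of-envelope bitrate for a residual quantizer stack.
# FRAME_RATE_HZ and MAX_LAYERS come from the model description above;
# the bits-per-layer argument is a hypothetical placeholder.
FRAME_RATE_HZ = 12.5   # token frames per second
MAX_LAYERS = 32        # residual LFQ layers

def bitrate_bps(num_layers: int, bits_per_layer: int) -> float:
    """Nominal bitrate when decoding with only the first `num_layers` layers."""
    assert 1 <= num_layers <= MAX_LAYERS
    return FRAME_RATE_HZ * num_layers * bits_per_layer

# Dropping layers trades fidelity for bitrate linearly: with a hypothetical
# 16 bits/layer, 32 layers -> 6400 bps while 8 layers -> 1600 bps.
```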
**Summary:**
By combining a simple, scalable architecture with massive-scale data, the Cat architecture overcomes the bottlenecks of traditional audio tokenizers. It provides a robust, high-fidelity, and semantically grounded interface for the next generation of native audio foundation models.
This repository contains a lightweight remote-code implementation that mirrors the current 🤗 Transformers
`transformers.models.moss_audio_tokenizer` module. It is intended to be uploaded to a Hugging Face Hub model repository
and loaded with `trust_remote_code=True` when needed.
## Usage
### Quickstart
```python
import torch
import torchaudio
from transformers import AutoModel

repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

wav, sr = torchaudio.load("demo/demo_gt.wav")
if sr != model.sampling_rate:
    wav = torchaudio.functional.resample(wav, sr, model.sampling_rate)
# Match the channel count expected by the model (stereo).
if wav.shape[0] == 1:
    wav = wav.repeat(model.config.number_channels, 1)
else:
    wav = wav[: model.config.number_channels]
wav = wav.unsqueeze(0)  # (B, C, T)

enc = model.encode(wav, return_dict=True)
print(f"enc.audio_codes.shape: {enc.audio_codes.shape}")
dec = model.decode(enc.audio_codes, return_dict=True)
print(f"dec.audio.shape: {dec.audio.shape}")

wav = dec.audio.squeeze(0)
torchaudio.save("demo/demo_rec.wav", wav, sample_rate=model.sampling_rate)

# Decode using only the first 8 layers of the residual quantizer stack
dec_rvq8 = model.decode(enc.audio_codes[:8], return_dict=True)
wav_rvq8 = dec_rvq8.audio.squeeze(0)
torchaudio.save("demo/demo_rec_rvq8.wav", wav_rvq8, sample_rate=model.sampling_rate)
```
### Attention Backend And Compute Dtype
`config.attention_implementation` controls whether transformer layers prefer `sdpa` or `flash_attention_2`.
`config.compute_dtype` controls the non-quantizer autocast dtype and supports `fp32`, `bf16`, and `fp16`.
```python
model.set_attention_implementation("flash_attention_2")
model.set_compute_dtype("fp16")
```
The quantizer always runs in fp32.
### Streaming
`MossAudioTokenizerModel.encode`, `decode`, `batch_encode`, and `batch_decode` all support streaming through a
`chunk_duration` argument.
- `chunk_duration` is expressed in seconds.
- `chunk_duration * MossAudioTokenizerConfig.sampling_rate` must be divisible by `MossAudioTokenizerConfig.downsample_rate`.
- Streaming batch inference is supported.
- The public waveform interface expects stereo inputs shaped `(2, T)` or batched stereo inputs shaped `(B, 2, T)`.
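The divisibility constraint above can be checked before calling `encode`. The helper below is a sketch, not part of the model's API; the default `sampling_rate=48000` and `downsample_rate=3840` match the values used in the example that follows:

```python
def is_valid_chunk_duration(chunk_duration: float,
                            sampling_rate: int = 48_000,
                            downsample_rate: int = 3_840) -> bool:
    """True if chunk_duration * sampling_rate is an integer multiple of downsample_rate."""
    samples = chunk_duration * sampling_rate
    nearest = round(samples)
    # Tolerate float rounding (e.g. 0.08 * 48000 is not exact in binary).
    return abs(samples - nearest) < 1e-6 and nearest % downsample_rate == 0

# 0.08 s -> 3840 samples: exactly one token frame at 12.5 Hz, valid.
# 0.10 s -> 4800 samples: not a multiple of 3840, invalid.
```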
```python
import torch
from transformers import AutoModel

repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

audio = torch.randn(2, 48000 * 6)  # dummy stereo waveform
# 6.0s @ 48kHz = 288000 samples, divisible by downsample_rate=3840

enc = model.encode(audio.unsqueeze(0), return_dict=True, chunk_duration=0.08)
dec = model.decode(enc.audio_codes, return_dict=True, chunk_duration=0.08)

batch_enc = model.batch_encode([audio, audio[:, : 48000 * 3]], chunk_duration=0.08)
codes_list = [
    batch_enc.audio_codes[:, i, : batch_enc.audio_codes_lengths[i]]
    for i in range(batch_enc.audio_codes.shape[1])
]
batch_dec = model.batch_decode(codes_list, chunk_duration=0.08)
```
#### Continuous Batch Streaming Decode
For decoder-side continuous batching, prefer `batch_decode(..., streaming=True, ...)`.
- The first streaming call may pass `max_batch_size=...`. If it is omitted, the batch size of the first call sets the
  fixed-slot decoder budget for that public stream.
- Same-size calls continue the existing logical rows in-order.
- If a later call is larger, the new rows are admitted by tail append.
- `finalize_indices` means "decode these rows one last time, then evict them". The indices are interpreted against the
pre-call logical order.
- After a finalize call returns, the next streaming call may use the smaller survivor batch.
- `reset_stream=True` discards the hidden public streaming state and starts a fresh stream.
Milestone 1 boundaries:
- decode-only continuous batching
- one active streaming decode state per model instance
- fixed-slot decoder reservation from `max_batch_size`
- no encode-side continuous batching
- no physical compaction of surviving decode slots
- no multi-session concurrency on one model instance
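The row bookkeeping described above (in-order continuation, tail append, finalize-then-evict) can be modeled with a small toy simulation. This is a pure-Python sketch of the logical-row semantics only, not the library's implementation (which also manages per-slot decoder state that this ignores):

```python
class LogicalRows:
    """Toy bookkeeping for the continuous batch streaming decode semantics."""

    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size  # fixed-slot decoder budget
        self.rows = []  # active stream ids in logical order

    def step(self, stream_ids, finalize_indices=()):
        """One streaming call: continue existing rows in-order, tail-append
        new rows, then evict finalized rows after they decode once more."""
        n = len(self.rows)
        assert list(stream_ids[:n]) == self.rows, "existing rows must continue in-order"
        new_rows = list(stream_ids[n:])  # larger calls admit rows by tail append
        assert len(self.rows) + len(new_rows) <= self.max_batch_size
        self.rows += new_rows
        decoded = list(self.rows)  # finalized rows still decode in this call
        # finalize_indices are interpreted against the pre-call logical order
        evict = set(finalize_indices)
        self.rows = [r for i, r in enumerate(decoded) if i not in evict]
        return decoded


rows = LogicalRows(max_batch_size=3)
assert rows.step(["A", "B"]) == ["A", "B"]            # first call reserves the budget
assert rows.step(["A", "B", "C"]) == ["A", "B", "C"]  # C admitted by tail append
assert rows.step(["A", "B", "C"], finalize_indices=[0]) == ["A", "B", "C"]  # A decodes, then evicted
assert rows.step(["B", "C"]) == ["B", "C"]            # next call shrinks to survivors
```

The same call sequence against the real model is shown below.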
```python
import torch
from transformers import AutoModel

repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

num_quantizers = model.config.quantizer_kwargs["num_quantizers"]
codes_a0 = torch.randint(0, 8, (num_quantizers, 2))
codes_b0 = torch.randint(0, 8, (num_quantizers, 3))
codes_a1 = torch.randint(0, 8, (num_quantizers, 2))
codes_b1 = torch.randint(0, 8, (num_quantizers, 2))
codes_c0 = torch.randint(0, 8, (num_quantizers, 1))
codes_a2 = torch.randint(0, 8, (num_quantizers, 1))
codes_b2 = torch.randint(0, 8, (num_quantizers, 2))
codes_c1 = torch.randint(0, 8, (num_quantizers, 2))
codes_b3 = torch.randint(0, 8, (num_quantizers, 1))
codes_c2 = torch.randint(0, 8, (num_quantizers, 1))

# First call reserves 3 fixed decoder slots for A and B.
out_ab0 = model.batch_decode(
    [codes_a0, codes_b0],
    streaming=True,
    max_batch_size=3,
    reset_stream=True,
)

# Same logical rows continue in-order; C is a tail append.
out_abc1 = model.batch_decode(
    [codes_a1, codes_b1, codes_c0],
    streaming=True,
)

# Finalize A against the pre-call logical order. A still decodes in this call,
# then is evicted immediately afterward.
out_abc2 = model.batch_decode(
    [codes_a2, codes_b2, codes_c1],
    streaming=True,
    finalize_indices=[0],
)

# The next call can shrink to the surviving logical rows only.
out_bc3 = model.batch_decode(
    [codes_b3, codes_c2],
    streaming=True,
)
```
## Repository layout
- `configuration_moss_audio_tokenizer.py`
- `modeling_moss_audio_tokenizer.py`
- `__init__.py`
- `config.json`
- model weights
## Citation
If you use this code or our results in your paper, please cite our work:
```tex
```