MCplayer committed on
Commit
0b4c806
·
1 Parent(s): 4b0005e

pre-release version

Browse files
README.md CHANGED
@@ -2,25 +2,68 @@
2
  license: apache-2.0
3
  ---
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  ```python
7
  import torchaudio
8
  from transformers import AutoFeatureExtractor, AutoModel
9
 
10
- wav_form, sampling_rate = torchaudio.load("examples/zh_spk1_moon.wav")
11
  feature_extractor = AutoFeatureExtractor.from_pretrained("MCplayer/XY_Tokenizer", trust_remote_code=True)
12
  codec = AutoModel.from_pretrained("MCplayer/XY_Tokenizer", trust_remote_code=True, device_map="auto").eval()
13
 
 
 
 
14
  if sampling_rate != 16000:
15
- resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
16
- wav_form = resampler(wav_form)
17
 
 
18
  input_spectrum = feature_extractor(wav_form, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")
 
19
  code = codec.encode(input_spectrum)
20
 
 
 
21
  output_wav = codec.decode(code["audio_codes"], overlap_seconds=10)
22
- for i, audio in enumerate(output_wav["audio_values"]):
23
- torchaudio.save(f"outputs/audio{i}.wav", audio.cpu(), 24000)
24
 
 
 
 
25
 
26
- ```
 
2
  license: apache-2.0
3
  ---
4
 
5
+ # **Introduction**
6
+
7
+ **`XY-Tokenizer`** is a speech codec that simultaneously models both semantic and acoustic aspects of speech, converting audio into discrete tokens and decoding them back to high-quality audio. It achieves efficient speech representation at only 1kbps with RVQ8 quantization at 12.5Hz frame rate.
8
+
9
+ - **Paper:** [Read on arXiv](https://arxiv.org/pdf/2506.23325)
10
+ - **Source Code:**
11
+ - [GitHub Repo](https://github.com/OpenMOSS/MOSS-TTSD/tree/main/XY_Tokenizer)
12
+ - [Hugging Face Repo](https://huggingface.co/spaces/fnlp/MOSS-TTSD/tree/main/XY_Tokenizer)
13
+
14
+ ## 📚 Related Project: **`MOSS-TTSD`**
15
+
16
+ **`XY-Tokenizer`** serves as the underlying neural codec for **`MOSS-TTSD`**, our 1.7B Audio Language Model. \
17
+ Explore **`MOSS-TTSD`** for advanced text-to-speech and other audio generation tasks on [GitHub](https://github.com/OpenMOSS/MOSS-TTSD), [Blog](http://www.open-moss.com/en/moss-ttsd/), [博客](https://www.open-moss.com/cn/moss-ttsd/), and [Space Demo](https://huggingface.co/spaces/fnlp/MOSS-TTSD).
18
+
19
+ ## ✨ Features
20
+
21
+ - **Dual-channel modeling**: Simultaneously captures semantic meaning and acoustic details
22
+ - **Efficient representation**: 1kbps bitrate with RVQ8 quantization at 12.5Hz
23
+ - **High-quality audio tokenization**: Convert speech to discrete tokens and back with minimal quality loss
24
+ - **Long audio support**: Process audio files longer than 30 seconds using chunking with overlap
25
+ - **Batch processing**: Efficiently process multiple audio files in batches
26
+ - **24kHz output**: Generate high-quality 24kHz audio output
27
+
28
+
29
+ ## 🚀 Installation
30
+
31
+ ```bash
32
+ git clone https://github.com/OpenMOSS/MOSS-TTSD.git
33
+ cd MOSS-TTSD
34
+ conda create -n xy_tokenizer python=3.10 -y && conda activate xy_tokenizer
35
+ pip install -r XY_Tokenizer/requirements.txt
36
+ ```
37
+
38
+ ## 💻 Quick Start
39
+
40
+ Here's how to use **`XY-Tokenizer`** with `transformers` to encode an audio file into discrete tokens and decode it back into a waveform.
41
 
42
  ```python
43
  import torchaudio
44
  from transformers import AutoFeatureExtractor, AutoModel
45
 
46
+ # 1. Load the feature extractor and the codec model
47
  feature_extractor = AutoFeatureExtractor.from_pretrained("MCplayer/XY_Tokenizer", trust_remote_code=True)
48
  codec = AutoModel.from_pretrained("MCplayer/XY_Tokenizer", trust_remote_code=True, device_map="auto").eval()
49
 
50
+ # 2. Load and preprocess the audio
51
+ # The model expects a 16kHz sample rate.
52
+ wav_form, sampling_rate = torchaudio.load("examples/zh_spk1_moon.wav")
53
  if sampling_rate != 16000:
54
+ wav_form = torchaudio.functional.resample(wav_form, orig_freq=sampling_rate, new_freq=16000)
 
55
 
56
+ # 3. Encode the audio into discrete codes
57
  input_spectrum = feature_extractor(wav_form, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")
58
+ # The 'code' dictionary contains the discrete audio codes
59
  code = codec.encode(input_spectrum)
60
 
61
+ # 4. Decode the codes back to an audio waveform
62
+ # The output is high-quality 24kHz audio.
63
  output_wav = codec.decode(code["audio_codes"], overlap_seconds=10)
 
 
64
 
65
+ # 5. Save the reconstructed audio
66
+ for i, audio in enumerate(output_wav["audio_values"]):
67
+ torchaudio.save(f"outputs/audio_{i}.wav", audio.cpu(), 24000)
68
 
69
+ ```
config.json CHANGED
@@ -21,7 +21,6 @@
21
  "padding_side": "right",
22
  "padding_value": 0.0,
23
  "sampling_rate": 16000,
24
- "encoder_downsample_rate": 1280,
25
  "return_attention_mask": true,
26
  "return_tensors": "pt"
27
  },
@@ -120,5 +119,7 @@
120
  "hop_size": 240,
121
  "padding": "same"
122
  }
123
- }
 
 
124
  }
 
21
  "padding_side": "right",
22
  "padding_value": 0.0,
23
  "sampling_rate": 16000,
 
24
  "return_attention_mask": true,
25
  "return_tensors": "pt"
26
  },
 
119
  "hop_size": 240,
120
  "padding": "same"
121
  }
122
+ },
123
+ "torch_dtype": "float32",
124
+ "transformers_version": "4.51.0"
125
  }
feature_extraction_xy_tokenizer.py CHANGED
@@ -15,6 +15,7 @@
15
  """
16
  Feature extractor class for Whisper
17
  """
 
18
  from functools import partial
19
  from typing import List, Optional, Union
20
 
@@ -37,7 +38,6 @@ class ExtractorIterator:
37
  chunk_length=30,
38
  overlap_seconds=10,
39
  sampling_rate=16000,
40
- encoder_downsample_rate=1280,
41
  encode_func = None,
42
  ) -> None:
43
  self.data = data
@@ -45,12 +45,11 @@ class ExtractorIterator:
45
  self.chunk_length = chunk_length
46
  self.overlap_seconds = overlap_seconds
47
  self.sampling_rate = sampling_rate
48
- self.encoder_downsample_rate = encoder_downsample_rate
49
 
50
  # duration_size 是每次处理的有效音频长度
 
51
  self.duration_seconds = self.chunk_length - self.overlap_seconds
52
  self.duration_size = int(self.duration_seconds * self.sampling_rate)
53
- self.code_duration_length = self.duration_size // self.encoder_downsample_rate
54
  # 注意:这里我们只处理不带重叠的块,重叠将在外部处理(如果需要)
55
  # 或者在迭代器内部更明确地处理。为了简化,我们假设分块是基于 duration_size
56
 
@@ -66,26 +65,30 @@ class ExtractorIterator:
66
  batch_num = 0
67
 
68
  # 注意:chunk_and_pad_view 输出的块大小是 duration_size
69
- wav_tensor = torch.zeros(self.batch_size, 1, self.duration_size)
70
  input_lengths = torch.zeros(self.batch_size, dtype=torch.long)
71
  input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)
72
 
73
- def chunk_and_pad_view(tensor, chunk_size, seq_no):
74
  x = tensor[0:1, :].unsqueeze(0)
 
 
 
75
  B, C, L = x.shape
76
- num_chunks = (L + chunk_size - 1) // chunk_size
77
- target_len = num_chunks * chunk_size
78
- pad_len = target_len - L
79
- padded_x = F.pad(x, (0, pad_len))
80
- output_tensor = padded_x.view(B, num_chunks, chunk_size).transpose(0, 1)
81
- output_lengths = torch.full((num_chunks,), chunk_size, dtype=torch.long)
82
- if pad_len > 0:
83
- output_lengths[-1] = chunk_size - pad_len
 
84
  output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
85
  return output_tensor, output_lengths, output_seq_no
86
 
87
  for i, sample in enumerate(self.data):
88
- sample_chunks, sample_lengths, sample_seq_no = chunk_and_pad_view(sample, self.duration_size, i)
89
 
90
  processed_in_sample = 0
91
  while processed_in_sample < len(sample_chunks):
@@ -115,7 +118,6 @@ class ExtractorIterator:
115
  ]
116
  yield BatchFeature({
117
  **self.encode_func(list_x),
118
- "input_lengths": input_lengths.clone(),
119
  "chunk_seq_no": input_seq_no.clone(),
120
  })
121
 
@@ -133,7 +135,6 @@ class ExtractorIterator:
133
  ]
134
  yield BatchFeature({
135
  **self.encode_func(list_x),
136
- "input_lengths": input_lengths.clone(),
137
  "chunk_seq_no": input_seq_no[:batch_num].clone(),
138
  })
139
 
@@ -143,7 +144,6 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
143
  self,
144
  feature_size=80,
145
  sampling_rate=16000,
146
- encoder_downsample_rate=1280,
147
  hop_length=160,
148
  chunk_length=30,
149
  n_fft=400,
@@ -166,7 +166,6 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
166
  **kwargs,
167
  )
168
  self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
169
- self.encoder_downsample_rate = encoder_downsample_rate
170
  self.batch_size = batch_size
171
  self.mel_filters = mel_filter_bank(
172
  num_frequency_bins=1 + n_fft // 2,
@@ -204,7 +203,6 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
204
  chunk_length=self.chunk_length,
205
  overlap_seconds=overlap_seconds,
206
  sampling_rate=self.sampling_rate,
207
- encoder_downsample_rate=self.encoder_downsample_rate,
208
  encode_func=partial(
209
  super().__call__,
210
  truncation=truncation,
 
15
  """
16
  Feature extractor class for Whisper
17
  """
18
+ import math
19
  from functools import partial
20
  from typing import List, Optional, Union
21
 
 
38
  chunk_length=30,
39
  overlap_seconds=10,
40
  sampling_rate=16000,
 
41
  encode_func = None,
42
  ) -> None:
43
  self.data = data
 
45
  self.chunk_length = chunk_length
46
  self.overlap_seconds = overlap_seconds
47
  self.sampling_rate = sampling_rate
 
48
 
49
  # duration_size 是每次处理的有效音频长度
50
+ self.chunk_size = int(self.chunk_length * self.sampling_rate)
51
  self.duration_seconds = self.chunk_length - self.overlap_seconds
52
  self.duration_size = int(self.duration_seconds * self.sampling_rate)
 
53
  # 注意:这里我们只处理不带重叠的块,重叠将在外部处理(如果需要)
54
  # 或者在迭代器内部更明确地处理。为了简化,我们假设分块是基于 duration_size
55
 
 
65
  batch_num = 0
66
 
67
  # 注意:chunk_and_pad_view 输出的块大小是 duration_size
68
+ wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
69
  input_lengths = torch.zeros(self.batch_size, dtype=torch.long)
70
  input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)
71
 
72
+ def chunk_and_pad_view(tensor, seq_no):
73
  x = tensor[0:1, :].unsqueeze(0)
74
+
75
+ stride = self.duration_size
76
+ kernel = self.chunk_size
77
  B, C, L = x.shape
78
+
79
+ num_chunks = math.ceil(L / stride)
80
+ target_len = (num_chunks - 1) * stride + kernel
81
+ padding_size = max(0, target_len - L)
82
+ x_padded = F.pad(x, (0, padding_size), "constant", 0)
83
+ output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
84
+ output_lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
85
+ if padding_size > 0:
86
+ output_lengths[-1] = kernel - padding_size
87
  output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
88
  return output_tensor, output_lengths, output_seq_no
89
 
90
  for i, sample in enumerate(self.data):
91
+ sample_chunks, sample_lengths, sample_seq_no = chunk_and_pad_view(sample, i)
92
 
93
  processed_in_sample = 0
94
  while processed_in_sample < len(sample_chunks):
 
118
  ]
119
  yield BatchFeature({
120
  **self.encode_func(list_x),
 
121
  "chunk_seq_no": input_seq_no.clone(),
122
  })
123
 
 
135
  ]
136
  yield BatchFeature({
137
  **self.encode_func(list_x),
 
138
  "chunk_seq_no": input_seq_no[:batch_num].clone(),
139
  })
140
 
 
144
  self,
145
  feature_size=80,
146
  sampling_rate=16000,
 
147
  hop_length=160,
148
  chunk_length=30,
149
  n_fft=400,
 
166
  **kwargs,
167
  )
168
  self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
 
169
  self.batch_size = batch_size
170
  self.mel_filters = mel_filter_bank(
171
  num_frequency_bins=1 + n_fft // 2,
 
203
  chunk_length=self.chunk_length,
204
  overlap_seconds=overlap_seconds,
205
  sampling_rate=self.sampling_rate,
 
206
  encode_func=partial(
207
  super().__call__,
208
  truncation=truncation,
modeling_xy_tokenizer.py CHANGED
@@ -120,11 +120,11 @@ class VectorQuantizerConfig:
120
  # ----------------------------------------------- #
121
  # All Helper Modules (Copied from source) #
122
  # ----------------------------------------------- #
123
- def sinusoids(length, channels, max_timescale=10000):
124
  assert channels % 2 == 0
125
  log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
126
  inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
127
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
128
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
129
 
130
 
@@ -840,6 +840,7 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
840
  self.enhanced_vocos = Vocos(**params['vocos_kwargs'])
841
  self.feature_extractor = params['feature_extractor_kwargs']
842
  # Store some config values for easier access
 
843
  self.nq = params['quantizer_kwargs']['num_quantizers']
844
 
845
  # Initialize weights and apply final processing
@@ -893,7 +894,7 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
893
 
894
  # 1. Iterate through chunks and store intermediate results
895
  for chunk_features in features:
896
- code_duration_length = features.code_duration_length
897
  # Always use return_dict=True for easier access to named outputs
898
  chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
899
  valid_code_lengths = torch.clamp(chunk_output.codes_lengths, 0, code_duration_length)
@@ -972,10 +973,8 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
972
  ) -> Union[XYTokenizerEncodeOutput, Tuple]:
973
  input_mel = features['input_features'].to(self.device, dtype=self.dtype)
974
  mel_attention_mask = features['attention_mask'].to(self.device)
975
- input_lengths = features['input_lengths'].to(self.device).unsqueeze(1)
976
- mel_output_length = mel_attention_mask.sum(dim=-1).long().unsqueeze(1)
977
- mel_output_length = torch.cat((mel_output_length, input_lengths), dim=1).min(dim=1).values
978
-
979
  # --- Encoder Path ---
980
  semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length)
981
  semantic_adapter_output, _ = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length)
@@ -983,8 +982,8 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
983
 
984
  concated_channel = torch.cat([semantic_adapter_output, acoustic_encoder_output], dim=1)
985
 
986
- pre_rvq_adapter_output, _ = self.pre_rvq_adapter(concated_channel, acoustic_encoder_output_length)
987
- downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, acoustic_encoder_output_length)
988
 
989
  n_quantizers = n_quantizers or self.quantizer.num_quantizers
990
  zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length, n_quantizers=n_quantizers)
 
120
  # ----------------------------------------------- #
121
  # All Helper Modules (Copied from source) #
122
  # ----------------------------------------------- #
123
+ def sinusoids(length, channels, max_timescale=10000, device=None):
124
  assert channels % 2 == 0
125
  log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
126
  inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
127
+ scaled_time = torch.arange(length, device=device)[:, np.newaxis] * inv_timescales[np.newaxis, :]
128
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
129
 
130
 
 
840
  self.enhanced_vocos = Vocos(**params['vocos_kwargs'])
841
  self.feature_extractor = params['feature_extractor_kwargs']
842
  # Store some config values for easier access
843
+ self.encoder_downsample_rate = config.encoder_downsample_rate
844
  self.nq = params['quantizer_kwargs']['num_quantizers']
845
 
846
  # Initialize weights and apply final processing
 
894
 
895
  # 1. Iterate through chunks and store intermediate results
896
  for chunk_features in features:
897
+ code_duration_length = features.duration_size // self.encoder_downsample_rate
898
  # Always use return_dict=True for easier access to named outputs
899
  chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
900
  valid_code_lengths = torch.clamp(chunk_output.codes_lengths, 0, code_duration_length)
 
973
  ) -> Union[XYTokenizerEncodeOutput, Tuple]:
974
  input_mel = features['input_features'].to(self.device, dtype=self.dtype)
975
  mel_attention_mask = features['attention_mask'].to(self.device)
976
+ mel_output_length = mel_attention_mask.sum(dim=-1).long()
977
+
 
 
978
  # --- Encoder Path ---
979
  semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length)
980
  semantic_adapter_output, _ = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length)
 
982
 
983
  concated_channel = torch.cat([semantic_adapter_output, acoustic_encoder_output], dim=1)
984
 
985
+ pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_channel, acoustic_encoder_output_length)
986
+ downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length)
987
 
988
  n_quantizers = n_quantizers or self.quantizer.num_quantizers
989
  zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length, n_quantizers=n_quantizers)
preprocessor_config.json CHANGED
@@ -8,7 +8,6 @@
8
  "padding_side": "right",
9
  "padding_value": 0.0,
10
  "sampling_rate": 16000,
11
- "encoder_downsample_rate": 1280,
12
  "return_attention_mask": true,
13
  "return_tensors": "pt"
14
  }
 
8
  "padding_side": "right",
9
  "padding_value": 0.0,
10
  "sampling_rate": 16000,
 
11
  "return_attention_mask": true,
12
  "return_tensors": "pt"
13
  }