wanglamao committed on
Commit 528efee · 1 Parent(s): e55409d
Files changed (42)
  1. app.py +206 -0
  2. data_utils/__init__.py +0 -0
  3. data_utils/audio_dataset_ark_audio.py +414 -0
  4. gpa_inference.py +293 -0
  5. models/__init__.py +0 -0
  6. models/bicodec_tokenizer/__init__.py +0 -0
  7. models/bicodec_tokenizer/base_model.py +87 -0
  8. models/bicodec_tokenizer/batch_processor.py +182 -0
  9. models/bicodec_tokenizer/models/__init__.py +0 -0
  10. models/bicodec_tokenizer/models/audio_tokenizer.py +164 -0
  11. models/bicodec_tokenizer/models/bicodec.py +248 -0
  12. models/bicodec_tokenizer/modules/blocks/layers.py +73 -0
  13. models/bicodec_tokenizer/modules/blocks/samper.py +115 -0
  14. models/bicodec_tokenizer/modules/blocks/vocos.py +373 -0
  15. models/bicodec_tokenizer/modules/encoder_decoder/feat_decoder.py +115 -0
  16. models/bicodec_tokenizer/modules/encoder_decoder/feat_encoder.py +107 -0
  17. models/bicodec_tokenizer/modules/encoder_decoder/wave_generator.py +88 -0
  18. models/bicodec_tokenizer/modules/fsq/finite_scalar_quantization.py +251 -0
  19. models/bicodec_tokenizer/modules/fsq/residual_fsq.py +355 -0
  20. models/bicodec_tokenizer/modules/speaker/__init__.py +0 -0
  21. models/bicodec_tokenizer/modules/speaker/ecapa_tdnn.py +267 -0
  22. models/bicodec_tokenizer/modules/speaker/perceiver_encoder.py +360 -0
  23. models/bicodec_tokenizer/modules/speaker/pooling_layers.py +298 -0
  24. models/bicodec_tokenizer/modules/speaker/speaker_encoder.py +136 -0
  25. models/bicodec_tokenizer/modules/vq/factorized_vector_quantize.py +187 -0
  26. models/bicodec_tokenizer/spark_detokenizer.py +106 -0
  27. models/bicodec_tokenizer/spark_tokenizer.py +244 -0
  28. models/bicodec_tokenizer/tokenizer_utils.py +44 -0
  29. models/bicodec_tokenizer/utils/__init__.py +0 -0
  30. models/bicodec_tokenizer/utils/audio.py +271 -0
  31. models/bicodec_tokenizer/utils/file.py +221 -0
  32. models/bicodec_tokenizer/utils/parse_options.sh +97 -0
  33. models/bicodec_tokenizer/utils/token_parser.py +187 -0
  34. models/glm_speech_tokenizer/__init__.py +0 -0
  35. models/glm_speech_tokenizer/batch_processor.py +182 -0
  36. models/glm_speech_tokenizer/configuration_whisper.py +37 -0
  37. models/glm_speech_tokenizer/generation_whisper.py +1828 -0
  38. models/glm_speech_tokenizer/modeling_whisper.py +0 -0
  39. models/glm_speech_tokenizer/speech_token_extractor.py +126 -0
  40. models/glm_speech_tokenizer/test_speech_token_extractor.py +136 -0
  41. models/glm_speech_tokenizer/utils.py +89 -0
  42. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,206 @@
1
+ # -*- coding: utf-8 -*-
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+ import argparse
6
+ import librosa
7
+ import soundfile as sf
8
+
9
+ from gpa_inference import GPAInference
10
+
11
+ # Global inference object placeholder
12
+ inference = None
13
+
14
+
15
+ def preprocess_audio(audio_path):
16
+ """Ensure audio is 16kHz mono"""
17
+ if not audio_path:
18
+ return None
19
+ try:
20
+ # Load audio with librosa: automatically resamples to sr=16000 and converts to mono
21
+ y, _ = librosa.load(audio_path, sr=16000, mono=True)
22
+
23
+ # Save processed audio to a new file to avoid conflicts
24
+ dir_name = os.path.dirname(audio_path)
25
+ base_name = os.path.basename(audio_path)
26
+ name, ext = os.path.splitext(base_name)
27
+ new_path = os.path.join(dir_name, f"{name}_16k.wav")
28
+
29
+ sf.write(new_path, y, 16000)
30
+ print(f"Preprocessed audio saved to: {new_path}")
31
+ return new_path
32
+ except Exception as e:
33
+ print(f"Error processing audio {audio_path}: {e}")
34
+ return audio_path
35
+
36
+
37
+ # ======================== Interface Call Logic ========================
38
+
39
+ def process_stt(audio_path):
40
+ global inference
41
+ if inference is None:
42
+ return "Model not initialized."
43
+
44
+ if not audio_path:
45
+ return "Please upload audio first."
46
+
47
+ # Preprocess audio
48
+ audio_path = preprocess_audio(audio_path)
49
+
50
+ # Direct inference call
51
+ return inference.run_stt(audio_path=audio_path, do_sample=False)
52
+
53
+ def process_tts_a(text, ref_audio):
54
+ global inference
55
+ if inference is None:
56
+ return None
57
+
58
+ if not text or not ref_audio:
59
+ return None
60
+
61
+ # Preprocess audio
62
+ ref_audio = preprocess_audio(ref_audio)
63
+
64
+ # Direct inference call
65
+ return inference.run_tts(
66
+ task="tts-a",
67
+ output_filename="tts_output.wav",
68
+ text=text,
69
+ ref_audio_path=ref_audio,
70
+ temperature=0.8,
71
+ do_sample=True,
72
+ )
73
+
74
+ def process_vc(src_audio, ref_audio):
75
+ global inference
76
+ if inference is None:
77
+ return None
78
+
79
+ if not src_audio or not ref_audio:
80
+ return None
81
+
82
+ # Preprocess audio
83
+ src_audio = preprocess_audio(src_audio)
84
+ ref_audio = preprocess_audio(ref_audio)
85
+
86
+ # Direct inference call
87
+ return inference.run_vc(
88
+ source_audio_path=src_audio,
89
+ ref_audio_path=ref_audio,
90
+ output_filename="vc_output.wav",
91
+ )
92
+
93
+ # ======================== Gradio UI Layout ========================
94
+
95
+ # Use a soft, premium theme with indigo/slate colors to replace the default orange
96
+ theme = gr.themes.Soft(
97
+ primary_hue="indigo",
98
+ secondary_hue="slate",
99
+ neutral_hue="slate",
100
+ font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
101
+ )
102
+
103
+ with gr.Blocks(title="General Purpose Audio System", theme=theme) as demo:
104
+ gr.Markdown("# General Purpose Audio System")
105
+ gr.Markdown("A full-featured STT, TTS, and VC demo interface based on GPAInference.")
106
+
107
+ with gr.Tabs():
108
+ # --- STT Tab ---
109
+ with gr.TabItem("🎙️ Speech to Text (STT)"):
110
+ with gr.Row():
111
+ stt_input = gr.Audio(label="Input Audio", type="filepath")
112
+ stt_output = gr.Textbox(label="Recognition Result", placeholder="Recognition result will be displayed here in real-time...", lines=5)
113
+ stt_btn = gr.Button("Start Recognition", variant="primary")
114
+ stt_btn.click(process_stt, inputs=stt_input, outputs=stt_output)
115
+
116
+ # --- TTS-A Tab ---
117
+ with gr.TabItem("👤 Text to Speech (TTS)"):
118
+ with gr.Row():
119
+ with gr.Column():
120
+ ttsa_text = gr.Textbox(label="Synthesis Text", value="Hello, I am generated by voice cloning.")
121
+ ttsa_ref = gr.Audio(label="Reference Audio (Voice Source)", type="filepath")
122
+ ttsa_output = gr.Audio(label="Synthesis Result")
123
+ ttsa_btn = gr.Button("Synthesize Now", variant="primary")
124
+ ttsa_btn.click(process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output)
125
+
126
+ # --- VC Tab ---
127
+ with gr.TabItem("🎭 Voice Conversion (VC)"):
128
+ with gr.Row():
129
+ with gr.Column():
130
+ vc_src = gr.Audio(label="Source Audio (Content Source)", type="filepath")
131
+ vc_ref = gr.Audio(label="Reference Audio (Voice Source)", type="filepath")
132
+ vc_output = gr.Audio(label="Conversion Result")
133
+ vc_btn = gr.Button("Start Conversion", variant="primary")
134
+ vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)
135
+
136
+
137
+ def parse_args():
138
+ parser = argparse.ArgumentParser(description="GPA Audio System GUI")
139
+
140
+ # Model Paths
141
+ parser.add_argument(
142
+ "--tokenizer_path",
143
+ type=str,
144
+ default="/data3/gpa_ckpt/gpa_final/glm-4-voice-tokenizer",
145
+ help="Path to GLM4 tokenizer",
146
+ )
147
+ parser.add_argument(
148
+ "--text_tokenizer_path",
149
+ type=str,
150
+ default="/data3/gpa_ckpt/gpa_final",
151
+ help="Path to text tokenizer",
152
+ )
153
+ parser.add_argument(
154
+ "--bicodec_tokenizer_path",
155
+ type=str,
156
+ default="/data3/gpa_ckpt/gpa_final/BiCodec/",
157
+ help="Path to BiCodec tokenizer",
158
+ )
159
+ parser.add_argument(
160
+ "--gpa_model_path",
161
+ type=str,
162
+ default="/data3/gpa_ckpt/gpa_final",
163
+ help="Path to GPA model",
164
+ )
165
+
166
+ # System Config
167
+ parser.add_argument(
168
+ "--output_dir",
169
+ type=str,
170
+ default="./output_gui",
171
+ help="Directory to save output files",
172
+ )
173
+ parser.add_argument(
174
+ "--device",
175
+ type=str,
176
+ default="cuda" if torch.cuda.is_available() else "cpu",
177
+ help="Device to use",
178
+ )
179
+
180
+ # Server Config
181
+ parser.add_argument(
182
+ "--server_name", type=str, default="0.0.0.0", help="Address for Gradio server"
183
+ )
184
+ parser.add_argument(
185
+ "--server_port", type=int, default=7868, help="Port for Gradio server"
186
+ )
187
+
188
+ return parser.parse_args()
189
+
190
+ args = parse_args()
191
+
192
+ # Instantiate Model
193
+ print(f"Initializing GPA Inference System on {args.device}...")
194
+ os.makedirs(args.output_dir, exist_ok=True)
195
+
196
+ inference = GPAInference(
197
+ tokenizer_path=args.tokenizer_path,
198
+ text_tokenizer_path=args.text_tokenizer_path,
199
+ bicodec_tokenizer_path=args.bicodec_tokenizer_path,
200
+ gpa_model_path=args.gpa_model_path,
201
+ output_dir=args.output_dir,
202
+ device=args.device,
203
+ )
204
+
205
+ # Launch Gradio Demo
206
+ demo.queue().launch(server_name=args.server_name, server_port=args.server_port)
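Note: the callbacks above are thin wrappers around GPAInference, so the same flow can be sanity-checked without the Gradio UI. A minimal sketch, assuming locally available checkpoints (the paths below are placeholders, not files shipped with this commit):

# Sketch: drive the STT path directly, bypassing Gradio (placeholder paths).
import torch
from gpa_inference import GPAInference

engine = GPAInference(
    tokenizer_path="/path/to/glm-4-voice-tokenizer",   # placeholder
    text_tokenizer_path="/path/to/gpa_final",          # placeholder
    bicodec_tokenizer_path="/path/to/BiCodec/",        # placeholder
    gpa_model_path="/path/to/gpa_final",               # placeholder
    output_dir="./output_gui",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
print(engine.run_stt(audio_path="example_16k_mono.wav", do_sample=False))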
data_utils/__init__.py ADDED
File without changes
data_utils/audio_dataset_ark_audio.py ADDED
@@ -0,0 +1,414 @@
1
+ import os
2
+ import re
3
+ from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
4
+ from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
5
+ from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
6
+ from transformers import PreTrainedTokenizer,AutoTokenizer,WhisperFeatureExtractor
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import logging
10
+ from typing import List, Dict, Any, Literal, Optional, Union
11
+ from datasets import load_dataset
12
+ from torch.utils.data import DataLoader
13
+
14
+ def has_punctuation(text: str) -> bool:
15
+ # Matches common Chinese and English punctuation marks
16
+ pattern = r"[,。!?;:()“”‘’、,.!?;:()\[\]{}\"']"
17
+ return bool(re.search(pattern, text))
18
+
19
+ ALL_TASKS = ["stt", "tts-a", "vc"]
20
+
21
+
22
+ class ark_infer_processor:
23
+ def __init__(
24
+ self,
25
+ glm_tokenizer: SpeechTokenExtractor,
26
+ bicodec_tokenizer: SparkTokenizer,
27
+ text_tokenizer: PreTrainedTokenizer,
28
+ max_length: int = 512,
29
+ glm_semantic_token_offset: int = 151727,
30
+ semantic_token_offset: int = 172207,
31
+ global_token_offset: int = 168111,
32
+ audio_path_name: str = "audio",
33
+ device: str = "cpu",
34
+ ):
35
+ self.glm_tokenizer = glm_tokenizer
36
+ self.bicodec_tokenizer = bicodec_tokenizer
37
+ self.text_tokenizer = text_tokenizer
38
+ self.max_length = max_length
39
+ self.glm_semantic_token_offset = glm_semantic_token_offset
40
+ self.semantic_token_offset = semantic_token_offset
41
+ self.global_token_offset = global_token_offset
42
+ self.device = device
43
+ self.audio_path_name = audio_path_name
44
+
45
+ def _process_example_stt(self, audio_path: str):
46
+
47
+ ## target audio
48
+ with torch.no_grad():
49
+ glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
50
+ glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
51
+
52
+ semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
53
+ glm_semantic_tokens_list = (
54
+ (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
55
+ )
56
+ semantic_tokens_list = (
57
+ (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
58
+ )
59
+ input_ids = (
60
+ self.text_tokenizer.encode("<|start_glm_token|>")
61
+ + glm_semantic_tokens_list
62
+ + self.text_tokenizer.encode("<|end_glm_token|>")
63
+ + self.text_tokenizer.encode("<|start_semantic_token|>")
64
+ + semantic_tokens_list
65
+ + self.text_tokenizer.encode("<|end_semantic_token|>")
66
+ + self.text_tokenizer.encode("<|start_content|>")
67
+ )
68
+ attention_mask = [1] * (len(input_ids))
69
+ return input_ids, attention_mask
70
+
71
+ def _process_example_tts_a(self, text: str, ref_audio_path: str):
72
+ with torch.no_grad():
73
+ global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
74
+ all_text = "<|start_content|>" + text + "<|end_content|>"
75
+ global_tokens_list = (
76
+ (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
77
+ )
78
+ text_tokens = self.text_tokenizer(
79
+ all_text, truncation=True, max_length=self.max_length
80
+ )
81
+ input_ids = (
82
+ self.text_tokenizer.encode("<|start_global_token|>")
83
+ + global_tokens_list
84
+ + self.text_tokenizer.encode("<|end_global_token|>")
85
+ + text_tokens["input_ids"]
86
+ )
87
+ attention_mask = [1] * len(input_ids)
88
+ return input_ids, attention_mask
89
+
90
+ def _process_example_vc(self, audio_path: str, ref_audio_path: str):
91
+ with torch.no_grad():
92
+ semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
93
+ new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
94
+ semantic_tokens_list = (
95
+ (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
96
+ )
97
+ new_global_tokens_list = (
98
+ (new_global_tokens + self.global_token_offset).cpu().tolist()[0][0]
99
+ )
100
+ all_str = (
101
+ "<|start_global_token|>"
102
+ + self.text_tokenizer.decode(new_global_tokens_list)
103
+ + "<|end_global_token|>"
104
+ + "<|start_semantic_token|>"
105
+ + self.text_tokenizer.decode(semantic_tokens_list)
106
+ + "<|end_semantic_token|>"
107
+ + "<|end_content|>"
108
+ )
109
+
110
+ inputs = self.text_tokenizer(all_str)
111
+ input_ids = inputs["input_ids"]
112
+ attention_mask = inputs["attention_mask"]
113
+ return input_ids, attention_mask
114
+
115
+ def process_input(
116
+ self,
117
+ task: Literal["stt", "tts-a", "vc"],
118
+ audio_path: str | None = None,
119
+ ref_audio_path: str | None = None,
120
+ text: str | None = None,
121
+ ):
122
+ """Load the specified audio and features, and return tokenized inputs for the given task."""
123
+
124
+ if task == "stt":
125
+ assert audio_path is not None
126
+ input_ids, attention_mask = self._process_example_stt(audio_path)
127
+ elif task == "tts-a":
128
+ assert ref_audio_path is not None and text is not None
129
+ input_ids, attention_mask = self._process_example_tts_a(
130
+ text, ref_audio_path
131
+ )
132
+ elif task == "vc":
133
+ assert audio_path is not None and ref_audio_path is not None
134
+ input_ids, attention_mask = self._process_example_vc(
135
+ audio_path, ref_audio_path
136
+ )
137
+ else:
138
+ raise ValueError(
139
+ f"Unsupported task: {task}, all supported tasks: {ALL_TASKS}"
140
+ )
141
+ return {
142
+ "input_ids": input_ids,
143
+ "attention_mask": attention_mask,
144
+ }
145
+
146
+
147
+ class ark_processor:
148
+ def __init__(self,
149
+ glm_tokenizer: SpeechTokenExtractor,
150
+ bicodec_tokenizer: SparkTokenizer,
151
+ text_tokenizer:PreTrainedTokenizer,
152
+ max_length:int = 512,
153
+ glm_semantic_token_offset:int = 151727,
154
+ semantic_token_offset: int =172207,
155
+ global_token_offset: int =168111,
156
+ audio_path_name:str = "audio",
157
+ device:str ='cpu'):
158
+ self.glm_tokenizer = glm_tokenizer
159
+ self.bicodec_tokenizer = bicodec_tokenizer
160
+ self.text_tokenizer = text_tokenizer
161
+ self.max_length = max_length
162
+ self.glm_semantic_token_offset =glm_semantic_token_offset
163
+ self.semantic_token_offset=semantic_token_offset
164
+ self.global_token_offset=global_token_offset
165
+ self.device = device
166
+ self.audio_path_name =audio_path_name
167
+
168
+ def process_example(self, example: Dict[str, Any]):
169
+ """
170
+ This function is executed in parallel by multiple CPU processes.
171
+ It loads, resamples, and performs feature extraction / tokenization for a single sample.
172
+ """
173
+ task = example.get("task", "stt")
174
+ audio_path = example.get(self.audio_path_name, "")
175
+ ref_audio_path = example.get("ref_audio", "")
176
+ vc_audio = example.get("vc_audio", "")
177
+ text = example.get("text", "")
178
+
179
+ if task == "stt":
180
+ ## target audio
181
+ with torch.no_grad():
182
+ glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
183
+ glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
184
+
185
+ semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
186
+ glm_semantic_tokens_list = (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
187
+ semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
188
+ # print(f"len of semantic is {len(semantic_tokens_list)}")
189
+ ## tokenize the text
190
+ text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length)
191
+
192
+ input_ids = self.text_tokenizer.encode("<|start_glm_token|>") + glm_semantic_tokens_list + self.text_tokenizer.encode("<|end_glm_token|>") \
193
+ + self.text_tokenizer.encode("<|start_semantic_token|>") + semantic_tokens_list + self.text_tokenizer.encode(
194
+ "<|end_semantic_token|>") \
195
+ + self.text_tokenizer.encode("<|start_content|>") + text_tokens["input_ids"] + self.text_tokenizer.encode("<|end_content|>") \
196
+ + self.text_tokenizer.encode("<|im_end|>")
197
+ attention_mask = [1] * (len(input_ids))
198
+ labels = [-100] * (len(semantic_tokens_list) + 5 + len(glm_semantic_tokens_list)) + text_tokens["input_ids"] + self.text_tokenizer.encode(
199
+ "<|end_content|>") + self.text_tokenizer.encode("<|im_end|>")
200
+
201
+ elif task == "tts-a":
202
+ with torch.no_grad():
203
+ semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
204
+ global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
205
+ all_text = "<|start_content|>" + text + "<|end_content|>"
206
+ global_tokens_list = (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
207
+ text_tokens = self.text_tokenizer(all_text, truncation=True, max_length=self.max_length)
208
+ semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
209
+ input_ids = self.text_tokenizer.encode("<|start_global_token|>") + global_tokens_list + self.text_tokenizer.encode(
210
+ "<|end_global_token|>") + text_tokens["input_ids"] + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
211
+ attention_mask = [1] * len(input_ids)
212
+ labels = [-100] * (len(text_tokens["input_ids"]) + 2 + len(global_tokens_list)) + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
213
+
214
+ elif task == "vc":
215
+ with torch.no_grad():
216
+ semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
217
+ global_tokens = self.bicodec_tokenizer.tokenize([audio_path])['global_tokens']
218
+ # global_tokens, semantic_tokens=self.bicodec_tokenizer.tokenize(audio_path=audio_path)
219
+ # new_global_tokens, new_semantic_tokens=self.bicodec_tokenizer.tokenize(vc_audio,ref_audio_path)
220
+ new_semantic_tokens = self.bicodec_tokenizer.tokenize([vc_audio])['semantic_tokens']
221
+ new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
222
+
223
+ global_tokens_list = (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
224
+ semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
225
+ new_global_tokens_list = (new_global_tokens + self.global_token_offset).cpu().tolist()[0][0]
226
+ new_semantic_tokens_list = (new_semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
227
+ all_str = "<|start_global_token|>" + self.text_tokenizer.decode(new_global_tokens_list) + "<|end_global_token|>" + "<|start_semantic_token|>" + self.text_tokenizer.decode(
228
+ semantic_tokens_list) + "<|end_semantic_token|>" + "<|end_content|>" + self.text_tokenizer.decode(new_semantic_tokens_list) + "<|im_end|>"
229
+
230
+ ##add token and mask
231
+ inputs = self.text_tokenizer(all_str)
232
+ input_ids = inputs['input_ids']
233
+ attention_mask = inputs['attention_mask']
234
+ labels = [-100] * (5 + len(new_global_tokens_list) + len(semantic_tokens_list)) + new_semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
235
+ else:
236
+ ## fall back to stt by default
237
+ with torch.no_grad():
238
+ glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
239
+ glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
240
+
241
+ semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
242
+ glm_semantic_tokens_list = (glm_semantic_tokens+self.glm_semantic_token_offset).cpu().tolist()[0]
243
+ semantic_tokens_list = (semantic_tokens+self.semantic_token_offset).cpu().tolist()[0]
244
+ # print(f"len of semantic is {len(semantic_tokens_list)}")
245
+ ## tokenize the text
246
+ text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length)
247
+
248
+ input_ids = self.text_tokenizer.encode("<|start_glm_token|>")+ glm_semantic_tokens_list + self.text_tokenizer.encode("<|end_glm_token|>") \
249
+ + self.text_tokenizer.encode("<|start_semantic_token|>")+ semantic_tokens_list + self.text_tokenizer.encode("<|end_semantic_token|>") \
250
+ + text_tokens["input_ids"] \
251
+ + self.text_tokenizer.encode("<|im_end|>")
252
+ attention_mask = [1]*(len(semantic_tokens_list)+4+len(glm_semantic_tokens_list)) +text_tokens["attention_mask"] +[1]
253
+ labels = [-100]*(len(semantic_tokens_list)+4+len(glm_semantic_tokens_list))+ text_tokens["input_ids"]+ self.text_tokenizer.encode("<|im_end|>")
254
+ return {
255
+ "input_ids": input_ids,
256
+ "attention_mask": attention_mask,
257
+ "labels": labels,
258
+ }
259
+
260
+
261
+ def create_tts_collate_fn(
262
+ pad_token_id: int,
263
+ processor, # ark_processor
264
+ max_length: Optional[int] = None,  # optional truncation limit, e.g. 512
265
+ truncation_side: str = "right"  # "right" or "left"; truncate from the right by default
266
+ ):
267
+ """
268
+ Factory for a collate_fn that does manual padding plus optional truncation.
269
+
270
+ Args:
271
+ pad_token_id: pad value used for input_ids
272
+ processor: your ark_processor; must provide .process_example()
273
+ max_length: if given, truncate each sample to this length before batching
274
+ truncation_side: "right" | "left", which side to truncate from
275
+ """
276
+ label_pad_value = -100
277
+ attention_mask_pad_value = 0
278
+
279
+ def _truncate_1d(x: torch.Tensor, keep_len: int, side: str) -> torch.Tensor:
280
+ if x.numel() <= keep_len:
281
+ return x
282
+ if side == "right":
283
+ return x[:keep_len]
284
+ elif side == "left":
285
+ return x[-keep_len:]
286
+ else:
287
+ raise ValueError(f"Unsupported truncation_side: {side}")
288
+
289
+ def _to_long_tensor(x) -> torch.Tensor:
290
+ if isinstance(x, torch.Tensor):
291
+ return x.detach().clone().long()
292
+ return torch.tensor(x, dtype=torch.long)
293
+
294
+ def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
295
+ # 1) Preprocess (filter out empty samples)
296
+ proc = [processor.process_example(ex) for ex in examples if ex]
297
+ proc = [d for d in proc if d and ("input_ids" in d) and ("attention_mask" in d) and ("labels" in d)]
298
+
299
+ if len(proc) == 0:
300
+ # Return an empty batch to avoid crashing the DataLoader
301
+ return {
302
+ "input_ids": torch.empty(0, dtype=torch.long),
303
+ "attention_mask": torch.empty(0, dtype=torch.long),
304
+ "labels": torch.empty(0, dtype=torch.long),
305
+ }
306
+
307
+ # 2) Per-sample truncation (if max_length is set)
308
+ if max_length is not None:
309
+ trimmed = []
310
+ for ex in proc:
311
+ ids = _to_long_tensor(ex["input_ids"])
312
+ mask = _to_long_tensor(ex["attention_mask"])
313
+ labs = _to_long_tensor(ex["labels"])
314
+
315
+ keep_len = min(max_length, ids.numel())
316
+ ids = _truncate_1d(ids, keep_len, truncation_side)
317
+ mask = _truncate_1d(mask, keep_len, truncation_side)
318
+ labs = _truncate_1d(labs, keep_len, truncation_side)
319
+
320
+ trimmed.append({"input_ids": ids, "attention_mask": mask, "labels": labs})
321
+ proc = trimmed
322
+
323
+ # 3) Compute the maximum length in this batch (after truncation)
324
+ max_len_in_batch = max(int(len(ex["input_ids"])) for ex in proc)
325
+
326
+ # 4) Right-pad each sample to the batch maximum length
327
+ padded_input_ids_list = []
328
+ padded_attention_mask_list = []
329
+ padded_labels_list = []
330
+
331
+ for ex in proc:
332
+ ids = _to_long_tensor(ex["input_ids"])
333
+ mask = _to_long_tensor(ex["attention_mask"])
334
+ labs = _to_long_tensor(ex["labels"])
335
+
336
+ need = max_len_in_batch - ids.numel()
337
+ if need < 0:
338
+ # Edge case: an over-long sample can slip through when max_length is None
339
+ keep_len = max_len_in_batch
340
+ ids = _truncate_1d(ids, keep_len, "right")
341
+ mask = _truncate_1d(mask, keep_len, "right")
342
+ labs = _truncate_1d(labs, keep_len, "right")
343
+ need = 0
344
+
345
+ pad_dims = (0, need)
346
+ ids = F.pad(ids, pad_dims, mode="constant", value=pad_token_id)
347
+ mask = F.pad(mask, pad_dims, mode="constant", value=attention_mask_pad_value)
348
+ labs = F.pad(labs, pad_dims, mode="constant", value=label_pad_value)
349
+
350
+ padded_input_ids_list.append(ids)
351
+ padded_attention_mask_list.append(mask)
352
+ padded_labels_list.append(labs)
353
+
354
+ # 5) Stack into a batch
355
+ batch = {
356
+ "input_ids": torch.stack(padded_input_ids_list, dim=0),
357
+ "attention_mask": torch.stack(padded_attention_mask_list, dim=0),
358
+ "labels": torch.stack(padded_labels_list, dim=0),
359
+ }
360
+ return batch
361
+
362
+ return collate_fn
363
+
364
+ if __name__ == "__main__":
365
+ device = "cuda:0"
366
+ bicodec_audio_tokenizer_path = "/data/arki_production/model/SparkAudio/Spark-TTS-0___5B/"
367
+ glm_speech_tokenizer_path = "/data/yumu/model/glm-4-voice-tokenizer"
368
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(glm_speech_tokenizer_path)
369
+ audio_model = WhisperVQEncoder.from_pretrained(glm_speech_tokenizer_path).eval().to(device)
370
+ glm_tokenizer = SpeechTokenExtractor(model=audio_model, feature_extractor=feature_extractor, device=device)
371
+
372
+ text_tokenizer = AutoTokenizer.from_pretrained("/data/yumu/model/ark_audio_v1_0_3_b",trust_remote_code=True)
373
+ bicodec_tokenizer = SparkTokenizer(model_path=bicodec_audio_tokenizer_path, device=device)
374
+ # Configuration
375
+ DATASET_PATH = "/data/yumu/glm_asr_vllm/test/data/test_meeting.jsonl"
376
+ MAX_LENGTH = 4096
377
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
378
+
379
+ print(f"Using device: {DEVICE}")
380
+
381
+
382
+ # --- 2. Load the streaming dataset ---
383
+
384
+ print(f"Loading dataset '{DATASET_PATH}' in streaming mode...")
385
+
386
+
387
+ streaming_dataset = load_dataset("json", data_files=DATASET_PATH, streaming=True)['train']
388
+ # --- 4. Build the data processing pipeline ---
389
+
390
+ print("Shuffling the data stream, buffer_size=10000...")
391
+ shuffled_dataset = streaming_dataset.shuffle(buffer_size=10000, seed=42)
392
+ processor = ark_processor(
393
+ glm_tokenizer=glm_tokenizer,
394
+ bicodec_tokenizer=bicodec_tokenizer,
395
+ text_tokenizer=text_tokenizer,
396
+ device = DEVICE,
397
+ audio_path_name="audio")
398
+ collate_fn = create_tts_collate_fn(text_tokenizer.pad_token_id,processor,max_length=4096)
399
+ # Create the final DataLoader
400
+ data_loader = DataLoader(
401
+ shuffled_dataset,
402
+ batch_size=10,  # adjust for your GPU memory and model size
403
+ collate_fn=collate_fn,
404
+ num_workers=0  # DataLoader workers that pull data from the shuffled stream
405
+ )
406
+ print("\n--- High-performance streaming DataLoader demo ---")
407
+ print("Fetching and displaying the first batch from the DataLoader:\n")
408
+ first_batch = next(iter(data_loader))
409
+
410
+ print("Successfully fetched the first batch! The data was padded in collate_fn.")
411
+ for key, value in first_batch.items():
412
+ if value is not None:
413
+ # print(f" - {key}: shape={value.shape}, dtype={value.dtype}")
414
+ print(f" - {key}: shape={value.shape}, dtype={value.dtype}")
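To make the padding and truncation contract of create_tts_collate_fn concrete, here is a minimal sketch using a stub processor instead of the real tokenizers (the token values are arbitrary; only the shapes matter):

# Sketch: create_tts_collate_fn pads to the batch maximum and truncates to max_length.
import torch

class _StubProcessor:
    def process_example(self, ex):
        n = ex["n"]
        return {
            "input_ids": list(range(n)),
            "attention_mask": [1] * n,
            "labels": [-100] * n,
        }

collate = create_tts_collate_fn(pad_token_id=0, processor=_StubProcessor(), max_length=6)
batch = collate([{"n": 4}, {"n": 9}])
# The 9-token sample is truncated to 6, the 4-token sample is right-padded to 6,
# so every tensor comes out with shape (2, 6).
print(batch["input_ids"].shape)    # torch.Size([2, 6])
print(batch["attention_mask"][0])  # tensor([1, 1, 1, 1, 0, 0])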
gpa_inference.py ADDED
@@ -0,0 +1,293 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ import soundfile as sf
5
+ import re
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, WhisperFeatureExtractor
7
+ import numpy as np
8
+
9
+ from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
10
+ from models.bicodec_tokenizer.spark_detokenizer import SparkDeTokenizer
11
+
12
+ from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
13
+ from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
14
+
15
+ from data_utils.audio_dataset_ark_audio import ark_infer_processor
16
+
17
+ class GPAInference:
18
+ def __init__(self, tokenizer_path, text_tokenizer_path, bicodec_tokenizer_path, gpa_model_path, output_dir, device):
19
+ self.tokenizer_path = tokenizer_path
20
+ self.text_tokenizer_path = text_tokenizer_path
21
+ self.bicodec_tokenizer_path = bicodec_tokenizer_path
22
+ self.gpa_model_path = gpa_model_path
23
+ self.output_dir = output_dir
24
+ self.device = device
25
+
26
+ print(f"Using device: {self.device}")
27
+ self._load_models()
28
+
29
+ def _load_models(self):
30
+ print("Loading tokenizers...")
31
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(self.tokenizer_path)
32
+ audio_model = WhisperVQEncoder.from_pretrained(self.tokenizer_path).eval().to(self.device)
33
+ self.glm_tokenizer = SpeechTokenExtractor(model=audio_model, feature_extractor=feature_extractor, device=self.device)
34
+ self.text_tokenizer = AutoTokenizer.from_pretrained(
35
+ self.text_tokenizer_path,
36
+ trust_remote_code=True
37
+ )
38
+
39
+ self.bicodec_tokenizer = SparkTokenizer(model_path=self.bicodec_tokenizer_path, device=self.device)
40
+ self.bicodec_detokenizer = SparkDeTokenizer(model_path=self.bicodec_tokenizer_path, device=self.device)
41
+ self.processor = ark_infer_processor(
42
+ glm_tokenizer=self.glm_tokenizer,
43
+ bicodec_tokenizer=self.bicodec_tokenizer,
44
+ text_tokenizer=self.text_tokenizer,
45
+ device=self.device,
46
+ audio_path_name="audio",
47
+ )
48
+
49
+ print("Loading model...")
50
+ self.model = AutoModelForCausalLM.from_pretrained(
51
+ self.gpa_model_path,
52
+ trust_remote_code=True
53
+ ).to(self.device)
54
+
55
+ def generate(self, inputs, **kwargs):
56
+ """
57
+ Base generation method that accepts dynamic generation parameters.
58
+ """
59
+ for k in inputs:
60
+ if isinstance(inputs[k], (list, np.ndarray)):
61
+ inputs[k] = torch.tensor(inputs[k]).unsqueeze(0).to(self.device)
62
+ elif isinstance(inputs[k], torch.Tensor):
63
+ inputs[k] = inputs[k].unsqueeze(0).to(self.device)
64
+
65
+ # Default generation config
66
+ generation_config = {
67
+ "max_new_tokens": 1000,
68
+ "do_sample": False,
69
+ "eos_token_id": self.text_tokenizer.convert_tokens_to_ids("<|im_end|>"),
70
+ }
71
+
72
+ # Override defaults with any passed kwargs
73
+ generation_config.update(kwargs)
74
+
75
+ # Remove keys that might be None if passed from args mistakenly
76
+ generation_config = {k: v for k, v in generation_config.items() if v is not None}
77
+ print(f"Generation config: {generation_config}")
78
+
79
+ outputs = self.model.generate(
80
+ input_ids=inputs["input_ids"],
81
+ attention_mask=inputs["attention_mask"],
82
+ **generation_config
83
+ )
84
+ return outputs
85
+
86
+ def run_stt(self, audio_path, **kwargs):
87
+ if not audio_path:
88
+ raise ValueError("audio_path is required for STT")
89
+
90
+ print("\n--- Speech to Text (STT) ---")
91
+
92
+ inputs = self.processor.process_input(
93
+ task="stt",
94
+ audio_path=audio_path,
95
+ )
96
+
97
+ # Recommended hyperparameters for STT (these override any caller-supplied kwargs)
98
+ kwargs = {
99
+ "max_new_tokens": 512,
100
+ "do_sample": False,
101
+ }
102
+
103
+ # Pass generation arguments (temperature, etc.) to generate
104
+ outputs = self.generate(inputs, **kwargs)
105
+ text = self.text_tokenizer.decode(outputs[0].tolist())
106
+
107
+ if "<|start_content|>" in text:
108
+ return text.split("<|start_content|>")[1].replace("<|im_end|>","").replace("<|end_content|>","")
109
+ else:
110
+ return text.replace("<|im_end|>","")
111
+
112
+ def run_tts(self, task, output_filename, text, ref_audio_path, **kwargs):
113
+ """
114
+ kwargs: generation parameters forwarded to model.generate (temperature, top_p, etc.)
115
+ """
116
+ if not text:
117
+ raise ValueError("text is required for TTS")
118
+
119
+ # Check ref_audio_path requirement based on task
120
+ if task == "tts-a" and not ref_audio_path:
121
+ raise ValueError(f"ref_audio_path is required for {task}")
122
+
123
+ # Recommended hyperparameters for TTS (these override any caller-supplied kwargs)
124
+ kwargs = {
125
+ "max_new_tokens": 512,
126
+ "temperature": 0.2,
127
+ "repetition_penalty": 1.2,
128
+ "do_sample": True,
129
+ }
130
+
131
+ print(f"\n--- {task.upper()} ---")
132
+ output_path = os.path.join(self.output_dir, output_filename)
133
+
134
+ # Pass processor specific args (e.g. emotion, pitch) here
135
+ inputs = self.processor.process_input(
136
+ task=task,
137
+ ref_audio_path=ref_audio_path,
138
+ text=text,
139
+ )
140
+
141
+ # Pass generation specific args (e.g. temperature) here
142
+ # Note: the recommended kwargs above take precedence; generate() supplies defaults for anything not set
143
+ outputs = self.generate(inputs, **kwargs)
144
+
145
+ text_output = self.text_tokenizer.decode(outputs[0].tolist())
146
+
147
+ if "<|end_content|>" in text_output:
148
+ content = text_output.split("<|end_content|>")[1]
149
+ else:
150
+ print("Warning: <|end_content|> not found")
151
+ content = text_output
152
+
153
+ audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
154
+ audio_list = [int(x) for x in audio_ids]
155
+
156
+ if ref_audio_path:
157
+ global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
158
+ else:
159
+ global_tokens = torch.zeros((1, 32), dtype=torch.long).to(self.device)
160
+
161
+ req = {
162
+ "global_tokens": global_tokens,
163
+ "semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
164
+ }
165
+ out = self.bicodec_detokenizer.detokenize(**req)
166
+ reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
167
+ # Simple DC offset removal
168
+ if reconstructed_wav.size > 0:
169
+ reconstructed_wav -= reconstructed_wav.mean()
170
+
171
+ sf.write(output_path, reconstructed_wav, 16000)
172
+ print(f"Saved output to {output_path}")
173
+ return 16000, reconstructed_wav
174
+
175
+ def run_vc(
176
+ self,
177
+ source_audio_path,
178
+ ref_audio_path,
179
+ output_filename="output_gpa_vc.wav",
180
+ **kwargs,
181
+ ):
182
+ if not source_audio_path:
183
+ raise ValueError("source_audio_path is required for VC")
184
+ if not ref_audio_path:
185
+ raise ValueError("ref_audio_path is required for VC")
186
+
187
+ print("\n--- Voice Conversion (VC) ---")
188
+ output_path = os.path.join(self.output_dir, output_filename)
189
+
190
+ inputs = self.processor.process_input(
191
+ task="vc",
192
+ audio_path=source_audio_path,
193
+ ref_audio_path=ref_audio_path,
194
+ )
195
+
196
+ outputs = self.generate(inputs, **kwargs)
197
+ text_output = self.text_tokenizer.decode(outputs[0].tolist())
198
+
199
+ if "<|end_content|>" in text_output:
200
+ content = text_output.split("<|end_content|>")[1]
201
+ else:
202
+ content = text_output
203
+
204
+ audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
205
+ audio_list = [int(x) for x in audio_ids]
206
+
207
+ global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
208
+
209
+ req = {
210
+ "global_tokens": global_tokens,
211
+ "semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
212
+ }
213
+ out = self.bicodec_detokenizer.detokenize(**req)
214
+ reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
215
+ if reconstructed_wav.size > 0:
216
+ reconstructed_wav -= reconstructed_wav.mean()
217
+
218
+ sf.write(output_path, reconstructed_wav, 16000)
219
+ print(f"Saved VC output to {output_path}")
220
+ return 16000, reconstructed_wav
221
+
222
+
223
+ def parse_args():
224
+ parser = argparse.ArgumentParser(description="GPA Inference Script")
225
+
226
+ # Paths
227
+ parser.add_argument("--tokenizer_path", type=str, default="/nasdata/model/gpa/glm-4-voice-tokenizer", help="Path to GLM4 tokenizer")
228
+ parser.add_argument("--text_tokenizer_path", type=str, default="/nasdata/model/gpa", help="Path to text tokenizer")
229
+ parser.add_argument("--bicodec_tokenizer_path", type=str, default="/nasdata/model/gpa/BiCodec/", help="Path to BiCodec tokenizer")
230
+ parser.add_argument("--gpa_model_path", type=str, default="/nasdata/model/gpa", help="Path to GPA model")
231
+
232
+ # Audio inputs
233
+ parser.add_argument(
234
+ "--ref_audio_path", type=str, default=None, help="Reference audio path"
235
+ )
236
+ parser.add_argument(
237
+ "--src_audio_path", type=str, default=None, help="Source audio path for VC/STT"
238
+ )
239
+
240
+ # Output
241
+ parser.add_argument("--output_dir", type=str, default=".", help="Directory to save output files")
242
+
243
+ # Device
244
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
245
+ parser.add_argument("--device", type=str, default=default_device, help="Device to use (e.g., cuda:0, cpu)")
246
+
247
+ # Task
248
+ parser.add_argument("--task", type=str, required=True, choices=["stt", "tts-a", "vc"], help="Task to run")
249
+
250
+ # TTS Inputs (Processor Arguments)
251
+ parser.add_argument("--text", type=str, default=None, help="Text for TTS")
252
+
253
+ return parser.parse_args()
254
+
255
+ def main():
256
+ args = parse_args()
257
+
258
+ # Ensure output directory exists
259
+ os.makedirs(args.output_dir, exist_ok=True)
260
+
261
+ inference = GPAInference(
262
+ tokenizer_path=args.tokenizer_path,
263
+ text_tokenizer_path=args.text_tokenizer_path,
264
+ bicodec_tokenizer_path=args.bicodec_tokenizer_path,
265
+ gpa_model_path=args.gpa_model_path,
266
+ output_dir=args.output_dir,
267
+ device=args.device,
268
+ )
269
+
270
+ if args.task == "stt":
271
+ if not args.src_audio_path:
272
+ raise ValueError("Error: --src_audio_path is required for STT task.")
273
+ # Pass gen_kwargs
274
+ result = inference.run_stt(audio_path=args.src_audio_path)
275
+ print("STT Result:", result)
276
+
277
+ elif args.task == "tts-a":
278
+ inference.run_tts(
279
+ task="tts-a",
280
+ output_filename="output_gpa_tts_a.wav",
281
+ text=args.text,
282
+ ref_audio_path=args.ref_audio_path,
283
+ )
284
+
285
+ elif args.task == "vc":
286
+ inference.run_vc(
287
+ source_audio_path=args.src_audio_path,
288
+ ref_audio_path=args.ref_audio_path,
289
+ output_filename="output_gpa_vc.wav",
290
+ )
291
+
292
+ if __name__ == "__main__":
293
+ main()
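run_tts and run_vc share the same post-processing step: the generated text is split on <|end_content|> and the BiCodec semantic token ids are pulled out with a regex before detokenization. A toy illustration of just that step (the decoded string below is made up; real output comes from model.generate):

# Sketch: recovering BiCodec semantic token ids from generated text (toy string).
import re
import torch

decoded = "<|start_content|>hello<|end_content|><|bicodec_semantic_12|><|bicodec_semantic_7|><|im_end|>"
content = decoded.split("<|end_content|>")[1]
audio_ids = [int(x) for x in re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)]
semantic_tokens = torch.tensor(audio_ids).unsqueeze(0)  # shape (1, num_tokens), as the detokenizer expects
print(audio_ids)              # [12, 7]
print(semantic_tokens.shape)  # torch.Size([1, 2])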
models/__init__.py ADDED
File without changes
models/bicodec_tokenizer/__init__.py ADDED
File without changes
models/bicodec_tokenizer/base_model.py ADDED
@@ -0,0 +1,87 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Time :2025/3/29 10:28
3
+ # Author :Hui Huang
4
+ import json
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import yaml
9
+
10
+ from .tokenizer_utils import load_config
11
+ import os
12
+ from safetensors.torch import load_file
13
+
14
+
15
+ class SparkBaseModel(nn.Module):
16
+ @classmethod
17
+ def from_pretrained(cls, model_path: str):
18
+ config = load_config(os.path.join(model_path, "config.yaml"))['audio_tokenizer']
19
+ model = cls(config)
20
+ state_dict = load_file(os.path.join(model_path, "model.safetensors"))
21
+ model.load_state_dict(state_dict, strict=False)
22
+ model.eval()
23
+ model.remove_weight_norm()
24
+ return model
25
+
26
+ def remove_weight_norm(self):
27
+ """Removes weight normalization from all layers."""
28
+
29
+ def _remove_weight_norm(m):
30
+ try:
31
+ torch.nn.utils.remove_weight_norm(m)
32
+ except ValueError:
33
+ pass # The module didn't have weight norm
34
+
35
+ self.apply(_remove_weight_norm)
36
+
37
+
38
+ class SnacBaseModel(nn.Module):
39
+ @classmethod
40
+ def from_config(cls, config_path):
41
+ with open(config_path, "r") as f:
42
+ config = json.load(f)
43
+ model = cls(**config)
44
+ return model
45
+
46
+ @classmethod
47
+ def from_pretrained(cls, model_path: str):
48
+ model = cls.from_config(os.path.join(model_path, "config.json"))
49
+ state_dict = torch.load(
50
+ os.path.join(model_path, "pytorch_model.bin"),
51
+ map_location="cpu", weights_only=True)
52
+ model.load_state_dict(state_dict, strict=False)
53
+ model.eval()
54
+ return model
55
+
56
+
57
+ class MegaBaseModel(nn.Module):
58
+ CKPT_NAME = "model"
59
+
60
+ @classmethod
61
+ def from_pretrained(cls, model_path: str):
62
+ config_file = None
63
+ ckpt_path = None
64
+ for file in os.listdir(model_path):
65
+ if file.endswith(".ckpt"):
66
+ ckpt_path = os.path.join(model_path, file)
67
+ if file.endswith(".yaml"):
68
+ config_file = os.path.join(model_path, file)
69
+ if ckpt_path is None:
70
+ raise FileNotFoundError(f"No checkpoint found at {model_path}")
71
+
72
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
73
+ state_dict_all = {
74
+ k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()
75
+ }
76
+ state_dict = state_dict_all[cls.CKPT_NAME]
77
+ state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
78
+
79
+ if config_file is not None:
80
+ with open(config_file) as f:
81
+ config = yaml.safe_load(f)
82
+ model = cls(config)
83
+ else:
84
+ model = cls()
85
+ model.load_state_dict(state_dict, strict=False)
86
+ model.eval()
87
+ return model
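For orientation: SparkBaseModel.from_pretrained expects a directory containing config.yaml (with an audio_tokenizer section) and model.safetensors, and passes that config dict to the subclass constructor. A purely illustrative subclass sketch (the layer sizes are invented, not taken from BiCodec):

# Sketch: the minimal shape of a SparkBaseModel subclass (illustrative only).
import torch.nn as nn

class TinySparkModel(SparkBaseModel):
    def __init__(self, config: dict):
        super().__init__()
        # config is the "audio_tokenizer" section of config.yaml
        self.proj = nn.Linear(config.get("in_dim", 80), config.get("out_dim", 128))

    def forward(self, x):
        return self.proj(x)

# model = TinySparkModel.from_pretrained("/path/to/model_dir")  # needs config.yaml + model.safetensors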
models/bicodec_tokenizer/batch_processor.py ADDED
@@ -0,0 +1,182 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Time :2024/11/17 15:33
3
+ # Author :Hui Huang
4
+ import asyncio
5
+ import uuid
6
+ from typing import Callable, List, Any, Awaitable, Tuple
7
+ from asyncio import Queue
8
+
9
+
10
+ class BatchProcessor:
11
+ """Batch Processor for handling asynchronous requests in batches.
12
+
13
+ This class manages a queue of requests and processes them in batches
14
+ using multiple worker tasks.
15
+
16
+ Attributes:
17
+ processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
18
+ The function used for processing requests in batches.
19
+ num_workers (int): The number of worker tasks to process requests.
20
+ batch_size (int): The maximum number of requests to process in a single batch.
21
+ request_queue (Queue): The queue holding incoming requests.
22
+ loop (asyncio.AbstractEventLoop): The event loop used to create worker tasks.
23
+ worker_tasks (List[asyncio.Task]): The list of worker tasks.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
29
+ num_workers: int,
30
+ batch_size: int,
31
+ wait_timeout: float = 0.05
32
+ ) -> None:
33
+ """Initialize the BatchProcessor with the given processing function, number of workers, and batch size.
34
+
35
+ Args:
36
+ processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
37
+ The function used for processing requests in batches.
38
+ num_workers (int): The number of worker tasks to process requests.
39
+ batch_size (int): The maximum number of requests to process in a single batch.
40
+ """
41
+ self.processing_function = processing_function
42
+ self.num_workers = num_workers
43
+ self.batch_size = batch_size
44
+ self.wait_timeout = wait_timeout
45
+ self.request_queue: Queue = Queue()
46
+ self.loop = asyncio.get_running_loop()
47
+ self.worker_tasks = [
48
+ self.loop.create_task(self.batch_processor(i)) for i in range(num_workers)
49
+ ]
50
+ # Wait until all worker tasks are started
51
+ self.loop.create_task(self._log_workers_started())
52
+
53
+ async def _log_workers_started(self):
54
+ await asyncio.sleep(0) # Yield control to ensure workers have started
55
+
56
+ async def batch_processor(self, worker_id: int):
57
+ """Worker task that processes requests from the queue in batches.
58
+
59
+ Args:
60
+ worker_id (int): The identifier for the worker task.
61
+ """
62
+
63
+ while True:
64
+ requests: List[Tuple[Any, asyncio.Future]] = []
65
+ try:
66
+ while len(requests) < self.batch_size:
67
+ request = await asyncio.wait_for(
68
+ self.request_queue.get(), timeout=self.wait_timeout
69
+ )
70
+ requests.append(request)
71
+ except asyncio.TimeoutError:
72
+ pass
73
+
74
+ if requests:
75
+ all_requests = [
76
+ req[0] for req in requests
77
+ ] # Extract the actual input data from each request tuple
78
+ futures = [req[1] for req in requests] # Extract the futures to resolve
79
+ try:
80
+ results = await self.processing_function(all_requests)
81
+
82
+ for (future, result) in zip(futures, results):
83
+ future.set_result(result)
84
+ except Exception as e:
85
+ for future in futures:
86
+ future.set_exception(e)
87
+
88
+ async def add_request(self, single_input: Any):
89
+ """Add a new request to the queue.
90
+
91
+ Args:
92
+ single_input (Any): The input data for processing.
93
+ """
94
+ # loop = asyncio.get_running_loop()
95
+ future = self.loop.create_future()
96
+ self.request_queue.put_nowait((single_input, future))
97
+ return future
98
+
99
+ async def shutdown(self):
100
+ """Shutdown the batch processor by cancelling all worker tasks."""
101
+ for task in self.worker_tasks:
102
+ task.cancel()
103
+ try:
104
+ await task
105
+ except asyncio.CancelledError:
106
+ print("Worker task cancelled.")
107
+
108
+
109
+ class AsyncBatchEngine:
110
+
111
+ def __init__(
112
+ self,
113
+ processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
114
+ batch_size: int = 32,
115
+ wait_timeout: float = 0.01,
116
+ ):
117
+ """
118
+ Initialize the AsyncBatchEngine with a processing function, number of workers, and batch size.
119
+
120
+ Args:
121
+ processing_function (Callable[[List[Any]], Awaitable[List[Any]]]): The batch processing function.
122
+ batch_size (int): The maximum number of requests to process in a single batch.
123
+ """
124
+ self._processing_function = processing_function
125
+ self._batch_size = batch_size
126
+ self._is_running = False
127
+ self._batch_processor = None
128
+ self._wait_timeout = wait_timeout
129
+
130
+ async def start(self):
131
+ """Start the engine by initializing the batch processor and worker tasks."""
132
+ if self._is_running:
133
+ return
134
+
135
+ self._batch_processor = BatchProcessor(
136
+ processing_function=self._processing_function,
137
+ batch_size=self._batch_size,
138
+ wait_timeout=self._wait_timeout,
139
+ num_workers=1
140
+ )
141
+ self._is_running = True
142
+
143
+ async def stop(self):
144
+ """Stop the engine by shutting down the batch processor and worker tasks."""
145
+ self._check_running()
146
+ self._is_running = False
147
+ if self._batch_processor is not None:
148
+ await self._batch_processor.shutdown()
149
+
150
+ def _check_running(self):
151
+ """Check if the engine is running.
152
+
153
+ Raises:
154
+ ValueError: If the engine is not running.
155
+ """
156
+ if not self._is_running:
157
+ raise ValueError(
158
+ "The engine is not running. "
159
+ "You must start the engine before using it."
160
+ )
161
+
162
+ async def add_request(self, single_input: Any, request_id: str = None) -> dict:
163
+ """Asynchronously add a request to be processed.
164
+
165
+ Args:
166
+ single_input (Any): The input data for processing.
167
+ request_id (str): Optional request identifier to avoid data mix-up.
168
+
169
+ Raises:
170
+ ValueError: If the engine is not running when this method is called.
171
+ """
172
+ if not self._is_running:
173
+ await self.start()
174
+
175
+ if request_id is None:
176
+ request_id = str(uuid.uuid4()) # Assign a unique ID if not provided
177
+ future = await self._batch_processor.add_request(single_input=single_input) # type: ignore
178
+ result = await future
179
+ return dict(
180
+ request_id=request_id,
181
+ feature=result
182
+ )
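A minimal sketch of how AsyncBatchEngine is meant to be used, with a toy async batch function standing in for a real tokenizer:

# Sketch: AsyncBatchEngine usage with a toy batch function.
import asyncio

async def double_all(items):
    # Receives up to batch_size queued inputs at once; returns one result per input, in order.
    return [x * 2 for x in items]

async def main():
    engine = AsyncBatchEngine(processing_function=double_all, batch_size=8, wait_timeout=0.01)
    results = await asyncio.gather(*(engine.add_request(i) for i in range(5)))
    print([r["feature"] for r in results])  # [0, 2, 4, 6, 8]
    await engine.stop()

asyncio.run(main())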
models/bicodec_tokenizer/models/__init__.py ADDED
File without changes
models/bicodec_tokenizer/models/audio_tokenizer.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import sys
17
+ sys.path.append("../..")
18
+ import torch
19
+ import numpy as np
20
+
21
+ from pathlib import Path
22
+ from typing import Any, Dict, Tuple
23
+ from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
24
+
25
+ from arktts.models.sparktts.utils.file import load_config
26
+ from arktts.models.sparktts.utils.audio import load_audio
27
+ from arktts.models.sparktts.models.bicodec import BiCodec
28
+
29
+
30
+ class BiCodecTokenizer:
31
+ """BiCodec tokenizer for handling audio input and tokenization."""
32
+
33
+ def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
34
+ super().__init__()
35
+ """
36
+ Args:
37
+ model_dir: Path to the model directory.
38
+ device: Device to run the model on (default is GPU if available).
39
+ """
40
+ self.device = device
41
+ self.model_dir = model_dir
42
+ self.config = load_config(f"{model_dir}/config.yaml")
43
+ self._initialize_model()
44
+
45
+ def _initialize_model(self):
46
+ """Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
47
+ self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
48
+ self.device
49
+ )
50
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
51
+ f"{self.model_dir}/wav2vec2-large-xlsr-53"
52
+ )
53
+ self.feature_extractor = Wav2Vec2Model.from_pretrained(
54
+ f"{self.model_dir}/wav2vec2-large-xlsr-53"
55
+ ).to(self.device)
56
+ self.feature_extractor.config.output_hidden_states = True
57
+
58
+ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
59
+ """Get reference audio clip for speaker embedding."""
60
+ ref_segment_length = (
61
+ int(self.config["sample_rate"] * self.config["ref_segment_duration"])
62
+ // self.config["latent_hop_length"]
63
+ * self.config["latent_hop_length"]
64
+ )
65
+ wav_length = len(wav)
66
+
67
+ if ref_segment_length > wav_length:
68
+ # Repeat and truncate to handle insufficient length
69
+ wav = np.tile(wav, ref_segment_length // wav_length + 1)
70
+
71
+ return wav[:ref_segment_length]
72
+
73
+ def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
74
+ """Load the audio and extract the reference clip from a wav path."""
75
+ wav = load_audio(
76
+ wav_path,
77
+ sampling_rate=self.config["sample_rate"],
78
+ volume_normalize=self.config["volume_normalize"],
79
+ )
80
+
81
+ wav_ref = self.get_ref_clip(wav)
82
+
83
+ wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
84
+ return wav, wav_ref
85
+
86
+ def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
87
+ """extract wav2vec2 features"""
88
+ inputs = self.processor(
89
+ wavs,
90
+ sampling_rate=16000,
91
+ return_tensors="pt",
92
+ padding=True,
93
+ output_hidden_states=True,
94
+ ).input_values
95
+ feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
96
+ feats_mix = (
97
+ feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
98
+ ) / 3
99
+
100
+ return feats_mix
101
+
102
+ def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
103
+ """tokenize the batch of audio
104
+
105
+ Args:
106
+ batch:
107
+ wavs (List[np.ndarray]): batch of audio
108
+ ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
109
+
110
+ Returns:
111
+ semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
112
+ global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
113
+ """
114
+ feats = self.extract_wav2vec2_features(batch["wav"])
115
+ batch["feat"] = feats
116
+ semantic_tokens, global_tokens = self.model.tokenize(batch)
117
+
118
+ return global_tokens, semantic_tokens
119
+
120
+ def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
121
+ """tokenize the audio"""
122
+ wav, ref_wav = self.process_audio(audio_path)
123
+ feat = self.extract_wav2vec2_features(wav)
124
+ batch = {
125
+ "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
126
+ "ref_wav": ref_wav.to(self.device),
127
+ "feat": feat.to(self.device),
128
+ }
129
+ semantic_tokens, global_tokens = self.model.tokenize(batch)
130
+
131
+ return global_tokens, semantic_tokens
132
+
133
+ def detokenize(
134
+ self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
135
+ ) -> np.array:
136
+ """detokenize the tokens to waveform
137
+
138
+ Args:
139
+ global_tokens: global tokens. shape: (batch_size, global_dim)
140
+ semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
141
+
142
+ Returns:
143
+ wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
144
+ """
145
+ global_tokens = global_tokens.unsqueeze(1)
146
+ wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
147
+ return wav_rec.detach().squeeze().cpu().numpy()
148
+
149
+
150
+ # test
151
+ if __name__ == "__main__":
152
+ import soundfile as sf
153
+
154
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155
+ tokenizer = BiCodecTokenizer(
156
+ model_dir="pretrained_models/Spark-TTS-0.5B",
157
+ device=device,
158
+ )
159
+ wav_path = "example/prompt_audio.wav"
160
+
161
+ global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
162
+
163
+ wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
164
+ sf.write("example/prompt_recon.wav", wav_rec, 16000)
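The get_ref_clip logic repeats a short input until it reaches a fixed reference length (a multiple of the latent hop length) and then truncates. A toy numeric sketch; the hop length and segment duration here are assumptions, the real values come from config.yaml:

# Sketch: reference-clip length handling (assumed hop/duration values).
import numpy as np

sample_rate, ref_seconds, hop = 16000, 6, 320           # assumed values for illustration
ref_len = int(sample_rate * ref_seconds) // hop * hop   # 96000, a multiple of the hop length
wav = np.random.randn(20000).astype(np.float32)         # shorter than the reference length
if ref_len > len(wav):
    wav = np.tile(wav, ref_len // len(wav) + 1)         # repeat until long enough
ref_clip = wav[:ref_len]
print(ref_clip.shape)  # (96000,)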
models/bicodec_tokenizer/models/bicodec.py ADDED
@@ -0,0 +1,248 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import sys
16
+ sys.path.append("../..")
17
+ import torch
18
+ import torch.nn as nn
19
+ from pathlib import Path
20
+ from typing import Dict, Any
21
+ from omegaconf import DictConfig
22
+ from safetensors.torch import load_file
23
+
24
+ from ..utils.file import load_config
25
+ from ..modules.speaker.speaker_encoder import SpeakerEncoder
26
+ from ..modules.encoder_decoder.feat_encoder import Encoder
27
+ from ..modules.encoder_decoder.feat_decoder import Decoder
28
+ from ..modules.encoder_decoder.wave_generator import WaveGenerator
29
+ from ..modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
30
+
31
+
32
+ class BiCodec(nn.Module):
33
+ """
34
+ BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
35
+ quantizer, and wave generator.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ mel_params: Dict[str, Any],
41
+ encoder: nn.Module,
42
+ decoder: nn.Module,
43
+ quantizer: nn.Module,
44
+ speaker_encoder: nn.Module,
45
+ prenet: nn.Module,
46
+ postnet: nn.Module,
47
+ **kwargs
48
+ ) -> None:
49
+ """
50
+ Initializes the BiCodec model with the required components.
51
+
52
+ Args:
53
+ mel_params (dict): Parameters for the mel-spectrogram transformer.
54
+ encoder (nn.Module): Encoder module.
55
+ decoder (nn.Module): Decoder module.
56
+ quantizer (nn.Module): Quantizer module.
57
+ speaker_encoder (nn.Module): Speaker encoder module.
58
+ prenet (nn.Module): Prenet network.
59
+ postnet (nn.Module): Postnet network.
60
+ """
61
+ super().__init__()
62
+ self.encoder = encoder
63
+ self.decoder = decoder
64
+ self.quantizer = quantizer
65
+ self.speaker_encoder = speaker_encoder
66
+ self.prenet = prenet
67
+ self.postnet = postnet
68
+ self.init_mel_transformer(mel_params)
69
+
70
+ @classmethod
71
+ def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
72
+ """
73
+ Loads the model from a checkpoint.
74
+
75
+ Args:
76
+ model_dir (Path): Path to the model directory containing checkpoint and config.
77
+
78
+ Returns:
79
+ BiCodec: The initialized BiCodec model.
80
+ """
81
+ ckpt_path = f'{model_dir}/model.safetensors'
82
+ config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
83
+ mel_params = config["mel_params"]
84
+ encoder = Encoder(**config["encoder"])
85
+ quantizer = FactorizedVectorQuantize(**config["quantizer"])
86
+ prenet = Decoder(**config["prenet"])
87
+ postnet = Decoder(**config["postnet"])
88
+ decoder = WaveGenerator(**config["decoder"])
89
+ speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
90
+
91
+ model = cls(
92
+ mel_params=mel_params,
93
+ encoder=encoder,
94
+ decoder=decoder,
95
+ quantizer=quantizer,
96
+ speaker_encoder=speaker_encoder,
97
+ prenet=prenet,
98
+ postnet=postnet,
99
+ )
100
+
101
+ state_dict = load_file(ckpt_path)
102
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
103
+
104
+ for key in missing_keys:
105
+ print(f"Missing tensor: {key}")
106
+ for key in unexpected_keys:
107
+ print(f"Unexpected tensor: {key}")
108
+
109
+ model.eval()
110
+ model.remove_weight_norm()
111
+
112
+ return model
113
+
114
+ def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
115
+ """
116
+ Performs a forward pass through the model.
117
+
118
+ Args:
119
+ batch (dict): A dictionary containing features, reference waveform, and target waveform.
120
+
121
+ Returns:
122
+ dict: A dictionary containing the reconstruction, features, and other metrics.
123
+ """
124
+ feat = batch["feat"]
125
+ mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
126
+
127
+ z = self.encoder(feat.transpose(1, 2))
128
+ vq_outputs = self.quantizer(z)
129
+
130
+ x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
131
+
132
+ conditions = d_vector
133
+ with_speaker_loss = False
134
+
135
+ x = self.prenet(vq_outputs["z_q"], conditions)
136
+ pred_feat = self.postnet(x)
137
+ x = x + conditions.unsqueeze(-1)
138
+ wav_recon = self.decoder(x)
139
+
140
+ return {
141
+ "vq_loss": vq_outputs["vq_loss"],
142
+ "perplexity": vq_outputs["perplexity"],
143
+ "cluster_size": vq_outputs["active_num"],
144
+ "recons": wav_recon,
145
+ "pred_feat": pred_feat,
146
+ "x_vector": x_vector,
147
+ "d_vector": d_vector,
148
+ "audios": batch["wav"].unsqueeze(1),
149
+ "with_speaker_loss": with_speaker_loss,
150
+ }
151
+
152
+ @torch.no_grad()
153
+ def tokenize(self, batch: Dict[str, Any]):
154
+ """
155
+ Tokenizes the input audio into semantic and global tokens.
156
+
157
+ Args:
158
+ batch (dict): The input audio features and reference waveform.
159
+
160
+ Returns:
161
+ tuple: Semantic tokens and global tokens.
162
+ """
163
+ feat = batch["feat"]
164
+ mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
165
+
166
+ z = self.encoder(feat.transpose(1, 2))
167
+ semantic_tokens = self.quantizer.tokenize(z)
168
+ global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
169
+
170
+ return semantic_tokens, global_tokens
171
+
172
+ @torch.no_grad()
173
+ def detokenize(self, semantic_tokens, global_tokens):
174
+ """
175
+ Detokenizes the semantic and global tokens into a waveform.
176
+
177
+ Args:
178
+ semantic_tokens (tensor): Semantic tokens.
179
+ global_tokens (tensor): Global tokens.
180
+
181
+ Returns:
182
+ tensor: Reconstructed waveform.
183
+ """
184
+ z_q = self.quantizer.detokenize(semantic_tokens)
185
+ d_vector = self.speaker_encoder.detokenize(global_tokens)
186
+ x = self.prenet(z_q, d_vector)
187
+ x = x + d_vector.unsqueeze(-1)
188
+ wav_recon = self.decoder(x)
189
+
190
+ return wav_recon
191
+
192
+ def init_mel_transformer(self, config: Dict[str, Any]):
193
+ """
194
+ Initializes the MelSpectrogram transformer based on the provided configuration.
195
+
196
+ Args:
197
+ config (dict): Configuration parameters for MelSpectrogram.
198
+ """
199
+ import torchaudio.transforms as TT
200
+
201
+ self.mel_transformer = TT.MelSpectrogram(
202
+ config["sample_rate"],
203
+ config["n_fft"],
204
+ config["win_length"],
205
+ config["hop_length"],
206
+ config["mel_fmin"],
207
+ config["mel_fmax"],
208
+ n_mels=config["num_mels"],
209
+ power=1,
210
+ norm="slaney",
211
+ mel_scale="slaney",
212
+ )
213
+
214
+ def remove_weight_norm(self):
215
+ """Removes weight normalization from all layers."""
216
+ def _remove_weight_norm(m):
217
+ try:
218
+ torch.nn.utils.remove_weight_norm(m)
219
+ except ValueError:
220
+ pass # The module didn't have weight norm
221
+
222
+ self.apply(_remove_weight_norm)
223
+
224
+
225
+ # Test the model
226
+ if __name__ == "__main__":
227
+
228
+ config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
229
+ model = BiCodec.load_from_checkpoint(
230
+ model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
231
+ )
232
+
233
+ # Generate random inputs for testing
234
+ duration = 0.96
235
+ x = torch.randn(20, 1, int(duration * 16000))
236
+ feat = torch.randn(20, int(duration * 50), 1024)
237
+ inputs = {"feat": feat, "wav": x, "ref_wav": x}
238
+
239
+ # Forward pass
240
+ outputs = model(inputs)
241
+ semantic_tokens, global_tokens = model.tokenize(inputs)
242
+ wav_recon = model.detokenize(semantic_tokens, global_tokens)
243
+
244
+ # Verify if the reconstruction matches
245
+ if torch.allclose(outputs["recons"].detach(), wav_recon):
246
+ print("Test successful")
247
+ else:
248
+ print("Test failed")
models/bicodec_tokenizer/modules/blocks/layers.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
17
+
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.utils import weight_norm
22
+
23
+
24
+ def WNConv1d(*args, **kwargs):
25
+ return weight_norm(nn.Conv1d(*args, **kwargs))
26
+
27
+
28
+ def WNConvTranspose1d(*args, **kwargs):
29
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
30
+
31
+
32
+ # Scripting this brings model speed up 1.4x
33
+ @torch.jit.script
34
+ def snake(x, alpha):
35
+ shape = x.shape
36
+ x = x.reshape(shape[0], shape[1], -1)
37
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
38
+ x = x.reshape(shape)
39
+ return x
40
+
41
+
42
+ class Snake1d(nn.Module):
43
+ def __init__(self, channels):
44
+ super().__init__()
45
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
46
+
47
+ def forward(self, x):
48
+ return snake(x, self.alpha)
49
+
50
+
51
+ class ResidualUnit(nn.Module):
52
+ def __init__(self, dim: int = 16, dilation: int = 1):
53
+ super().__init__()
54
+ pad = ((7 - 1) * dilation) // 2
55
+ self.block = nn.Sequential(
56
+ Snake1d(dim),
57
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
58
+ Snake1d(dim),
59
+ WNConv1d(dim, dim, kernel_size=1),
60
+ )
61
+
62
+ def forward(self, x):
63
+ y = self.block(x)
64
+ pad = (x.shape[-1] - y.shape[-1]) // 2
65
+ if pad > 0:
66
+ x = x[..., pad:-pad]
67
+ return x + y
68
+
69
+
70
+ def init_weights(m):
71
+ if isinstance(m, nn.Conv1d):
72
+ nn.init.trunc_normal_(m.weight, std=0.02)
73
+ nn.init.constant_(m.bias, 0)
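The helpers above (weight-normalized convolutions, the Snake activation, and ResidualUnit) all preserve sequence length. A minimal, hypothetical sketch of how they compose (import path assumed from this repo's layout, sizes illustrative):

import torch
import torch.nn as nn
from models.bicodec_tokenizer.modules.blocks.layers import Snake1d, WNConv1d, ResidualUnit

# stack a weight-normalized conv, a dilated residual unit, and a snake activation
block = nn.Sequential(
    WNConv1d(16, 16, kernel_size=7, padding=3),
    ResidualUnit(dim=16, dilation=3),
    Snake1d(16),
)
x = torch.randn(2, 16, 100)   # (batch, channels, time)
print(block(x).shape)         # expected: torch.Size([2, 16, 100]) -- length preserved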
models/bicodec_tokenizer/modules/blocks/samper.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+
22
+ class SamplingBlock(nn.Module):
23
+ """Sampling block for upsampling or downsampling"""
24
+
25
+ def __init__(
26
+ self,
27
+ dim: int,
28
+ groups: int = 1,
29
+ upsample_scale: int = 1,
30
+ downsample_scale: int = 1,
31
+ ) -> None:
32
+ """
33
+ Args:
34
+ dim: input dimension
35
+ groups: number of groups
36
+ upsample_scale: upsampling scale
37
+ downsample_scale: downsampling scale
38
+ """
39
+ super(SamplingBlock, self).__init__()
40
+
41
+ self.upsample_scale = upsample_scale
42
+ self.downsample_scale = downsample_scale
43
+
44
+ if self.upsample_scale > 1:
45
+ self.de_conv_upsampler = nn.Sequential(
46
+ nn.LeakyReLU(0.2),
47
+ nn.ConvTranspose1d(
48
+ dim,
49
+ dim,
50
+ kernel_size=upsample_scale * 2,
51
+ stride=upsample_scale,
52
+ padding=upsample_scale // 2 + upsample_scale % 2,
53
+ output_padding=upsample_scale % 2,
54
+ groups=groups,
55
+ ),
56
+ )
57
+
58
+ if self.downsample_scale > 1:
59
+ self.conv_downsampler = nn.Sequential(
60
+ nn.LeakyReLU(0.2),
61
+ nn.Conv1d(
62
+ dim,
63
+ dim,
64
+ kernel_size=2 * downsample_scale,
65
+ stride=downsample_scale,
66
+ padding=downsample_scale // 2 + downsample_scale % 2,
67
+ groups=groups,
68
+ ),
69
+ )
70
+
71
+ @staticmethod
72
+ def repeat_upsampler(x, upsample_scale):
73
+ return x.repeat_interleave(upsample_scale, dim=2)
74
+
75
+ @staticmethod
76
+ def skip_downsampler(x, downsample_scale):
77
+ return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
78
+
79
+ def forward(self, x):
80
+ x = x.transpose(1, 2)
81
+ if self.upsample_scale > 1:
82
+ repeat_res = self.repeat_upsampler(x, self.upsample_scale)
83
+ deconv_res = self.de_conv_upsampler(x)
84
+ upmerge_res = repeat_res + deconv_res
85
+ else:
86
+ upmerge_res = x
87
+ repeat_res = x
88
+
89
+ if self.downsample_scale > 1:
90
+ conv_res = self.conv_downsampler(upmerge_res)
91
+ skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
92
+ skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
93
+ else:
94
+ conv_res = upmerge_res
95
+ skip2_res = upmerge_res
96
+ skip1_res = repeat_res
97
+
98
+ final_res = conv_res + skip1_res + skip2_res
99
+
100
+ return final_res
101
+
102
+
103
+ # test
104
+ if __name__ == "__main__":
105
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
106
+ model = SamplingBlock(1024, 1024, upsample_scale=2)
107
+ model_down = SamplingBlock(1024, 1024, downsample_scale=2)
108
+ output = model(test_input)
109
+ output_down = model_down(test_input)
110
+ print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
111
+ print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
112
+ if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
113
+ [8, 1024, 25]
114
+ ):
115
+ print("test successful")
models/bicodec_tokenizer/modules/blocks/vocos.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import Tuple
21
+ from torch.nn.utils import weight_norm, remove_weight_norm
22
+
23
+ from typing import Optional
24
+
25
+
26
+ class ConvNeXtBlock(nn.Module):
27
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
28
+
29
+ Args:
30
+ dim (int): Number of input channels.
31
+ intermediate_dim (int): Dimensionality of the intermediate layer.
32
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
33
+ Defaults to None.
34
+ condition_dim (int, optional): Dimension of the conditioning vector for AdaLayerNorm.
35
+ None means non-conditional LayerNorm. Defaults to None.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ dim: int,
41
+ intermediate_dim: int,
42
+ layer_scale_init_value: float,
43
+ condition_dim: Optional[int] = None,
44
+ ):
45
+ super().__init__()
46
+ self.dwconv = nn.Conv1d(
47
+ dim, dim, kernel_size=7, padding=3, groups=dim
48
+ ) # depthwise conv
49
+ self.adanorm = condition_dim is not None
50
+ if condition_dim:
51
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
52
+ else:
53
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
54
+ self.pwconv1 = nn.Linear(
55
+ dim, intermediate_dim
56
+ ) # pointwise/1x1 convs, implemented with linear layers
57
+ self.act = nn.GELU()
58
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
59
+ self.gamma = (
60
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
61
+ if layer_scale_init_value > 0
62
+ else None
63
+ )
64
+
65
+ def forward(
66
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
67
+ ) -> torch.Tensor:
68
+ residual = x
69
+ x = self.dwconv(x)
70
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
71
+ if self.adanorm:
72
+ assert cond_embedding_id is not None
73
+ x = self.norm(x, cond_embedding_id)
74
+ else:
75
+ x = self.norm(x)
76
+ x = self.pwconv1(x)
77
+ x = self.act(x)
78
+ x = self.pwconv2(x)
79
+ if self.gamma is not None:
80
+ x = self.gamma * x
81
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
82
+
83
+ x = residual + x
84
+ return x
85
+
86
+
87
+ class AdaLayerNorm(nn.Module):
88
+ """
89
+ Adaptive Layer Normalization module conditioned on a continuous embedding via learned scale and shift projections
90
+
91
+ Args:
92
+ condition_dim (int): Dimension of the condition.
93
+ embedding_dim (int): Dimension of the embeddings.
94
+ """
95
+
96
+ def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
97
+ super().__init__()
98
+ self.eps = eps
99
+ self.dim = embedding_dim
100
+ self.scale = nn.Linear(condition_dim, embedding_dim)
101
+ self.shift = nn.Linear(condition_dim, embedding_dim)
102
+ torch.nn.init.ones_(self.scale.weight)
103
+ torch.nn.init.zeros_(self.shift.weight)
104
+
105
+ def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
106
+ scale = self.scale(cond_embedding)
107
+ shift = self.shift(cond_embedding)
108
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
109
+ x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
110
+ return x
111
+
112
+
113
+ class ResBlock1(nn.Module):
114
+ """
115
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
116
+ but without upsampling layers.
117
+
118
+ Args:
119
+ dim (int): Number of input channels.
120
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
121
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
122
+ Defaults to (1, 3, 5).
123
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
124
+ Defaults to 0.1.
125
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
126
+ Defaults to None.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ dim: int,
132
+ kernel_size: int = 3,
133
+ dilation: Tuple[int, int, int] = (1, 3, 5),
134
+ lrelu_slope: float = 0.1,
135
+ layer_scale_init_value: Optional[float] = None,
136
+ ):
137
+ super().__init__()
138
+ self.lrelu_slope = lrelu_slope
139
+ self.convs1 = nn.ModuleList(
140
+ [
141
+ weight_norm(
142
+ nn.Conv1d(
143
+ dim,
144
+ dim,
145
+ kernel_size,
146
+ 1,
147
+ dilation=dilation[0],
148
+ padding=self.get_padding(kernel_size, dilation[0]),
149
+ )
150
+ ),
151
+ weight_norm(
152
+ nn.Conv1d(
153
+ dim,
154
+ dim,
155
+ kernel_size,
156
+ 1,
157
+ dilation=dilation[1],
158
+ padding=self.get_padding(kernel_size, dilation[1]),
159
+ )
160
+ ),
161
+ weight_norm(
162
+ nn.Conv1d(
163
+ dim,
164
+ dim,
165
+ kernel_size,
166
+ 1,
167
+ dilation=dilation[2],
168
+ padding=self.get_padding(kernel_size, dilation[2]),
169
+ )
170
+ ),
171
+ ]
172
+ )
173
+
174
+ self.convs2 = nn.ModuleList(
175
+ [
176
+ weight_norm(
177
+ nn.Conv1d(
178
+ dim,
179
+ dim,
180
+ kernel_size,
181
+ 1,
182
+ dilation=1,
183
+ padding=self.get_padding(kernel_size, 1),
184
+ )
185
+ ),
186
+ weight_norm(
187
+ nn.Conv1d(
188
+ dim,
189
+ dim,
190
+ kernel_size,
191
+ 1,
192
+ dilation=1,
193
+ padding=self.get_padding(kernel_size, 1),
194
+ )
195
+ ),
196
+ weight_norm(
197
+ nn.Conv1d(
198
+ dim,
199
+ dim,
200
+ kernel_size,
201
+ 1,
202
+ dilation=1,
203
+ padding=self.get_padding(kernel_size, 1),
204
+ )
205
+ ),
206
+ ]
207
+ )
208
+
209
+ self.gamma = nn.ParameterList(
210
+ [
211
+ (
212
+ nn.Parameter(
213
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
214
+ )
215
+ if layer_scale_init_value is not None
216
+ else None
217
+ ),
218
+ (
219
+ nn.Parameter(
220
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
221
+ )
222
+ if layer_scale_init_value is not None
223
+ else None
224
+ ),
225
+ (
226
+ nn.Parameter(
227
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
228
+ )
229
+ if layer_scale_init_value is not None
230
+ else None
231
+ ),
232
+ ]
233
+ )
234
+
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
237
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
238
+ xt = c1(xt)
239
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
240
+ xt = c2(xt)
241
+ if gamma is not None:
242
+ xt = gamma * xt
243
+ x = xt + x
244
+ return x
245
+
246
+ def remove_weight_norm(self):
247
+ for l in self.convs1:
248
+ remove_weight_norm(l)
249
+ for l in self.convs2:
250
+ remove_weight_norm(l)
251
+
252
+ @staticmethod
253
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
254
+ return int((kernel_size * dilation - dilation) / 2)
255
+
256
+
257
+ class Backbone(nn.Module):
258
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
259
+
260
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
261
+ """
262
+ Args:
263
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
264
+ C denotes output features, and L is the sequence length.
265
+
266
+ Returns:
267
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
268
+ and H denotes the model dimension.
269
+ """
270
+ raise NotImplementedError("Subclasses must implement the forward method.")
271
+
272
+
273
+ class VocosBackbone(Backbone):
274
+ """
275
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
276
+
277
+ Args:
278
+ input_channels (int): Number of input features channels.
279
+ dim (int): Hidden dimension of the model.
280
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
281
+ num_layers (int): Number of ConvNeXtBlock layers.
282
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
283
+ condition_dim (int, optional): Dimension of the conditioning vector for AdaLayerNorm.
284
+ None means non-conditional model. Defaults to None.
285
+ """
286
+
287
+ def __init__(
288
+ self,
289
+ input_channels: int,
290
+ dim: int,
291
+ intermediate_dim: int,
292
+ num_layers: int,
293
+ layer_scale_init_value: Optional[float] = None,
294
+ condition_dim: Optional[int] = None,
295
+ ):
296
+ super().__init__()
297
+ self.input_channels = input_channels
298
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
299
+ self.adanorm = condition_dim is not None
300
+ if condition_dim:
301
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
302
+ else:
303
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
304
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
305
+ self.convnext = nn.ModuleList(
306
+ [
307
+ ConvNeXtBlock(
308
+ dim=dim,
309
+ intermediate_dim=intermediate_dim,
310
+ layer_scale_init_value=layer_scale_init_value,
311
+ condition_dim=condition_dim,
312
+ )
313
+ for _ in range(num_layers)
314
+ ]
315
+ )
316
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
317
+ self.apply(self._init_weights)
318
+
319
+ def _init_weights(self, m):
320
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
321
+ nn.init.trunc_normal_(m.weight, std=0.02)
322
+ nn.init.constant_(m.bias, 0)
323
+
324
+ def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
325
+ x = self.embed(x)
326
+ if self.adanorm:
327
+ assert condition is not None
328
+ x = self.norm(x.transpose(1, 2), condition)
329
+ else:
330
+ x = self.norm(x.transpose(1, 2))
331
+ x = x.transpose(1, 2)
332
+ for conv_block in self.convnext:
333
+ x = conv_block(x, condition)
334
+ x = self.final_layer_norm(x.transpose(1, 2))
335
+ return x
336
+
337
+
338
+ class VocosResNetBackbone(Backbone):
339
+ """
340
+ Vocos backbone module built with ResBlocks.
341
+
342
+ Args:
343
+ input_channels (int): Number of input features channels.
344
+ dim (int): Hidden dimension of the model.
345
+ num_blocks (int): Number of ResBlock1 blocks.
346
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ input_channels,
352
+ dim,
353
+ num_blocks,
354
+ layer_scale_init_value=None,
355
+ ):
356
+ super().__init__()
357
+ self.input_channels = input_channels
358
+ self.embed = weight_norm(
359
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
360
+ )
361
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
362
+ self.resnet = nn.Sequential(
363
+ *[
364
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
365
+ for _ in range(num_blocks)
366
+ ]
367
+ )
368
+
369
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
370
+ x = self.embed(x)
371
+ x = self.resnet(x)
372
+ x = x.transpose(1, 2)
373
+ return x
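VocosBackbone keeps the temporal resolution and returns features in (batch, time, dim) order, as the Backbone docstring states. A rough usage sketch with illustrative sizes (not the shipped configuration; import path assumed):

import torch
from models.bicodec_tokenizer.modules.blocks.vocos import VocosBackbone

backbone = VocosBackbone(
    input_channels=80,       # e.g. mel bins
    dim=192,
    intermediate_dim=768,
    num_layers=4,
)
x = torch.randn(2, 80, 100)  # (batch, input_channels, time)
y = backbone(x)
print(y.shape)               # expected: torch.Size([2, 100, 192])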
models/bicodec_tokenizer/modules/encoder_decoder/feat_decoder.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+
22
+ from ..blocks.vocos import VocosBackbone
23
+ from ..blocks.samper import SamplingBlock
24
+
25
+
26
+ class Decoder(nn.Module):
27
+ """Decoder module with convnext and upsampling blocks
28
+
29
+ Args:
30
+ sample_ratios (List[int]): sample ratios
31
+ example: [2, 2] means upsample by 2x twice (4x overall)
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ input_channels: int,
37
+ vocos_dim: int,
38
+ vocos_intermediate_dim: int,
39
+ vocos_num_layers: int,
40
+ out_channels: int,
41
+ condition_dim: int = None,
42
+ sample_ratios: List[int] = [1, 1],
43
+ use_tanh_at_final: bool = False,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.linear_pre = nn.Linear(input_channels, vocos_dim)
48
+ modules = [
49
+ nn.Sequential(
50
+ SamplingBlock(
51
+ dim=vocos_dim,
52
+ groups=vocos_dim,
53
+ upsample_scale=ratio,
54
+ ),
55
+ VocosBackbone(
56
+ input_channels=vocos_dim,
57
+ dim=vocos_dim,
58
+ intermediate_dim=vocos_intermediate_dim,
59
+ num_layers=2,
60
+ condition_dim=None,
61
+ ),
62
+ )
63
+ for ratio in sample_ratios
64
+ ]
65
+
66
+ self.downsample = nn.Sequential(*modules)
67
+
68
+ self.vocos_backbone = VocosBackbone(
69
+ input_channels=vocos_dim,
70
+ dim=vocos_dim,
71
+ intermediate_dim=vocos_intermediate_dim,
72
+ num_layers=vocos_num_layers,
73
+ condition_dim=condition_dim,
74
+ )
75
+ self.linear = nn.Linear(vocos_dim, out_channels)
76
+ self.use_tanh_at_final = use_tanh_at_final
77
+
78
+ def forward(self, x: torch.Tensor, c: torch.Tensor = None):
79
+ """encoder forward.
80
+
81
+ Args:
82
+ x (torch.Tensor): (batch_size, input_channels, length)
83
+
84
+ Returns:
85
+ x (torch.Tensor): (batch_size, out_channels, length)
86
+ """
87
+ x = self.linear_pre(x.transpose(1, 2))
88
+ x = self.downsample(x).transpose(1, 2)
89
+ x = self.vocos_backbone(x, condition=c)
90
+ x = self.linear(x).transpose(1, 2)
91
+ if self.use_tanh_at_final:
92
+ x = torch.tanh(x)
93
+
94
+ return x
95
+
96
+
97
+ # test
98
+ if __name__ == "__main__":
99
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
100
+ condition = torch.randn(8, 256)
101
+ decoder = Decoder(
102
+ input_channels=1024,
103
+ vocos_dim=384,
104
+ vocos_intermediate_dim=2048,
105
+ vocos_num_layers=12,
106
+ out_channels=256,
107
+ condition_dim=256,
108
+ sample_ratios=[2, 2],
109
+ )
110
+ output = decoder(test_input, condition)
111
+ print(output.shape) # torch.Size([8, 256, 200])
112
+ if output.shape == torch.Size([8, 256, 200]):
113
+ print("Decoder test passed")
114
+ else:
115
+ print("Decoder test failed")
models/bicodec_tokenizer/modules/encoder_decoder/feat_encoder.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+ import sys
22
+ sys.path.append("../../../..")
23
+ sys.path.append("../../../../..")
24
+ from ..blocks.vocos import VocosBackbone
25
+ from ..blocks.samper import SamplingBlock
26
+
27
+
28
+ class Encoder(nn.Module):
29
+ """Encoder module with convnext and downsampling blocks"""
30
+
31
+ def __init__(
32
+ self,
33
+ input_channels: int,
34
+ vocos_dim: int,
35
+ vocos_intermediate_dim: int,
36
+ vocos_num_layers: int,
37
+ out_channels: int,
38
+ sample_ratios: List[int] = [1, 1],
39
+ ):
40
+ super().__init__()
41
+ """
42
+ Encoder module with VocosBackbone and sampling blocks.
43
+
44
+ Args:
45
+ sample_ratios (List[int]): sample ratios
46
+ example: [2, 2] means downsample by 2x twice (4x overall)
47
+ """
48
+ self.encoder = VocosBackbone(
49
+ input_channels=input_channels,
50
+ dim=vocos_dim,
51
+ intermediate_dim=vocos_intermediate_dim,
52
+ num_layers=vocos_num_layers,
53
+ condition_dim=None,
54
+ )
55
+
56
+ modules = [
57
+ nn.Sequential(
58
+ SamplingBlock(
59
+ dim=vocos_dim,
60
+ groups=vocos_dim,
61
+ downsample_scale=ratio,
62
+ ),
63
+ VocosBackbone(
64
+ input_channels=vocos_dim,
65
+ dim=vocos_dim,
66
+ intermediate_dim=vocos_intermediate_dim,
67
+ num_layers=2,
68
+ condition_dim=None,
69
+ ),
70
+ )
71
+ for ratio in sample_ratios
72
+ ]
73
+
74
+ self.downsample = nn.Sequential(*modules)
75
+
76
+ self.project = nn.Linear(vocos_dim, out_channels)
77
+
78
+ def forward(self, x: torch.Tensor, *args):
79
+ """
80
+ Args:
81
+ x (torch.Tensor): (batch_size, input_channels, length)
82
+
83
+ Returns:
84
+ x (torch.Tensor): (batch_size, encode_channels, length)
85
+ """
86
+ x = self.encoder(x)
87
+ x = self.downsample(x)
88
+ x = self.project(x)
89
+ return x.transpose(1, 2)
90
+
91
+
92
+ # test
93
+ if __name__ == "__main__":
94
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
95
+ encoder = Encoder(
96
+ input_channels=1024,
97
+ vocos_dim=384,
98
+ vocos_intermediate_dim=2048,
99
+ vocos_num_layers=12,
100
+ out_channels=256,
101
+ sample_ratios=[2, 2],
102
+ )
103
+
104
+ output = encoder(test_input)
105
+ print(output.shape) # torch.Size([8, 256, 12])
106
+ if output.shape == torch.Size([8, 256, 12]):
107
+ print("test successful")
models/bicodec_tokenizer/modules/encoder_decoder/wave_generator.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
16
+
17
+
18
+ import torch.nn as nn
19
+
20
+ from ..blocks.layers import (
21
+ Snake1d,
22
+ WNConv1d,
23
+ ResidualUnit,
24
+ WNConvTranspose1d,
25
+ init_weights,
26
+ )
27
+
28
+
29
+ class DecoderBlock(nn.Module):
30
+ def __init__(
31
+ self,
32
+ input_dim: int = 16,
33
+ output_dim: int = 8,
34
+ kernel_size: int = 2,
35
+ stride: int = 1,
36
+ ):
37
+ super().__init__()
38
+ self.block = nn.Sequential(
39
+ Snake1d(input_dim),
40
+ WNConvTranspose1d(
41
+ input_dim,
42
+ output_dim,
43
+ kernel_size=kernel_size,
44
+ stride=stride,
45
+ padding=(kernel_size - stride) // 2,
46
+ ),
47
+ ResidualUnit(output_dim, dilation=1),
48
+ ResidualUnit(output_dim, dilation=3),
49
+ ResidualUnit(output_dim, dilation=9),
50
+ )
51
+
52
+ def forward(self, x):
53
+ return self.block(x)
54
+
55
+
56
+ class WaveGenerator(nn.Module):
57
+ def __init__(
58
+ self,
59
+ input_channel,
60
+ channels,
61
+ rates,
62
+ kernel_sizes,
63
+ d_out: int = 1,
64
+ ):
65
+ super().__init__()
66
+
67
+ # Add first conv layer
68
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
69
+
70
+ # Add upsampling + MRF blocks
71
+ for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
72
+ input_dim = channels // 2**i
73
+ output_dim = channels // 2 ** (i + 1)
74
+ layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
75
+
76
+ # Add final conv layer
77
+ layers += [
78
+ Snake1d(output_dim),
79
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
80
+ nn.Tanh(),
81
+ ]
82
+
83
+ self.model = nn.Sequential(*layers)
84
+
85
+ self.apply(init_weights)
86
+
87
+ def forward(self, x):
88
+ return self.model(x)
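WaveGenerator upsamples a latent sequence into a waveform, with a total upsampling factor equal to the product of `rates`. A small, hypothetical configuration to illustrate the shapes (not taken from the shipped BiCodec config; import path assumed):

import torch
from models.bicodec_tokenizer.modules.encoder_decoder.wave_generator import WaveGenerator

gen = WaveGenerator(
    input_channel=1024,      # latent channels fed to the generator
    channels=512,            # width of the first stage; halved after each DecoderBlock
    rates=[4, 2],            # per-stage upsampling, 8x in total
    kernel_sizes=[8, 4],     # transposed-conv kernel sizes per stage
)
z = torch.randn(2, 1024, 50) # (batch, input_channel, frames)
wav = gen(z)
print(wav.shape)             # expected: torch.Size([2, 1, 400]) -- 50 frames * 8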
models/bicodec_tokenizer/modules/fsq/finite_scalar_quantization.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3
+ Code adapted from Jax version in Appendix A.1
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from functools import wraps, partial
8
+ from contextlib import nullcontext
9
+ from typing import List, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import Module
14
+ from torch import Tensor, int32
15
+ from torch.amp import autocast
16
+
17
+ from einops import rearrange, pack, unpack
18
+
19
+ # helper functions
20
+
21
+
22
+ def exists(v):
23
+ return v is not None
24
+
25
+
26
+ def default(*args):
27
+ for arg in args:
28
+ if exists(arg):
29
+ return arg
30
+ return None
31
+
32
+
33
+ def maybe(fn):
34
+ @wraps(fn)
35
+ def inner(x, *args, **kwargs):
36
+ if not exists(x):
37
+ return x
38
+ return fn(x, *args, **kwargs)
39
+
40
+ return inner
41
+
42
+
43
+ def pack_one(t, pattern):
44
+ return pack([t], pattern)
45
+
46
+
47
+ def unpack_one(t, ps, pattern):
48
+ return unpack(t, ps, pattern)[0]
49
+
50
+
51
+ # tensor helpers
52
+
53
+
54
+ def round_ste(z: Tensor) -> Tensor:
55
+ """Round with straight through gradients."""
56
+ zhat = z.round()
57
+ return z + (zhat - z).detach()
58
+
59
+
60
+ # main class
61
+
62
+
63
+ class FSQ(Module):
64
+ def __init__(
65
+ self,
66
+ levels: List[int],
67
+ dim: int | None = None,
68
+ num_codebooks=1,
69
+ keep_num_codebooks_dim: bool | None = None,
70
+ scale: float | None = None,
71
+ allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
72
+ channel_first: bool = False,
73
+ projection_has_bias: bool = True,
74
+ return_indices=True,
75
+ force_quantization_f32=True,
76
+ ):
77
+ super().__init__()
78
+ _levels = torch.tensor(levels, dtype=int32)
79
+ self.register_buffer("_levels", _levels, persistent=False)
80
+
81
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
82
+ self.register_buffer("_basis", _basis, persistent=False)
83
+
84
+ self.scale = scale
85
+
86
+ codebook_dim = len(levels)
87
+ self.codebook_dim = codebook_dim
88
+
89
+ effective_codebook_dim = codebook_dim * num_codebooks
90
+ self.num_codebooks = num_codebooks
91
+ self.effective_codebook_dim = effective_codebook_dim
92
+
93
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
94
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
95
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
96
+
97
+ self.dim = default(dim, len(_levels) * num_codebooks)
98
+
99
+ self.channel_first = channel_first
100
+
101
+ has_projections = self.dim != effective_codebook_dim
102
+ self.project_in = (
103
+ nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
104
+ if has_projections
105
+ else nn.Identity()
106
+ )
107
+ self.project_out = (
108
+ nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
109
+ if has_projections
110
+ else nn.Identity()
111
+ )
112
+
113
+ self.has_projections = has_projections
114
+
115
+ self.return_indices = return_indices
116
+ if return_indices:
117
+ self.codebook_size = self._levels.prod().item()
118
+ implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
119
+ self.register_buffer(
120
+ "implicit_codebook", implicit_codebook, persistent=False
121
+ )
122
+
123
+ self.allowed_dtypes = allowed_dtypes
124
+ self.force_quantization_f32 = force_quantization_f32
125
+
126
+ def bound(self, z, eps: float = 1e-3):
127
+ """Bound `z`, an array of shape (..., d)."""
128
+ half_l = (self._levels - 1) * (1 + eps) / 2
129
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
130
+ shift = (offset / half_l).atanh()
131
+ return (z + shift).tanh() * half_l - offset
132
+
133
+ def quantize(self, z):
134
+ """Quantizes z, returns quantized zhat, same shape as z."""
135
+ quantized = round_ste(self.bound(z))
136
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
137
+ return quantized / half_width
138
+
139
+ def _scale_and_shift(self, zhat_normalized):
140
+ half_width = self._levels // 2
141
+ return (zhat_normalized * half_width) + half_width
142
+
143
+ def _scale_and_shift_inverse(self, zhat):
144
+ half_width = self._levels // 2
145
+ return (zhat - half_width) / half_width
146
+
147
+ def _indices_to_codes(self, indices):
148
+ level_indices = self.indices_to_level_indices(indices)
149
+ codes = self._scale_and_shift_inverse(level_indices)
150
+ return codes
151
+
152
+ def codes_to_indices(self, zhat):
153
+ """Converts a `code` to an index in the codebook."""
154
+ assert zhat.shape[-1] == self.codebook_dim
155
+ zhat = self._scale_and_shift(zhat)
156
+ return (zhat * self._basis).sum(dim=-1).to(int32)
157
+
158
+ def indices_to_level_indices(self, indices):
159
+ """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
160
+ indices = rearrange(indices, "... -> ... 1")
161
+ codes_non_centered = (indices // self._basis) % self._levels
162
+ return codes_non_centered
163
+
164
+ def indices_to_codes(self, indices):
165
+ """Inverse of `codes_to_indices`."""
166
+ assert exists(indices)
167
+
168
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
169
+
170
+ codes = self._indices_to_codes(indices)
171
+
172
+ if self.keep_num_codebooks_dim:
173
+ codes = rearrange(codes, "... c d -> ... (c d)")
174
+
175
+ codes = self.project_out(codes)
176
+
177
+ if is_img_or_video or self.channel_first:
178
+ codes = rearrange(codes, "b ... d -> b d ...")
179
+
180
+ return codes
181
+
182
+ def forward(self, z):
183
+ """
184
+ einstein notation
185
+ b - batch
186
+ n - sequence (or flattened spatial dimensions)
187
+ d - feature dimension
188
+ c - number of codebook dim
189
+ """
190
+
191
+ is_img_or_video = z.ndim >= 4
192
+ need_move_channel_last = is_img_or_video or self.channel_first
193
+
194
+ # standardize image or video into (batch, seq, dimension)
195
+
196
+ if need_move_channel_last:
197
+ z = rearrange(z, "b d ... -> b ... d")
198
+ z, ps = pack_one(z, "b * d")
199
+
200
+ assert (
201
+ z.shape[-1] == self.dim
202
+ ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
203
+
204
+ z = self.project_in(z)
205
+
206
+ z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
207
+
208
+ # whether to force quantization step to be full precision or not
209
+
210
+ force_f32 = self.force_quantization_f32
211
+ quantization_context = (
212
+ partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
213
+ )
214
+
215
+ with quantization_context():
216
+ orig_dtype = z.dtype
217
+
218
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
219
+ z = z.float()
220
+
221
+ codes = self.quantize(z)
222
+
223
+ # returning indices could be optional
224
+
225
+ indices = None
226
+
227
+ if self.return_indices:
228
+ indices = self.codes_to_indices(codes)
229
+
230
+ codes = rearrange(codes, "b n c d -> b n (c d)")
231
+
232
+ codes = codes.type(orig_dtype)
233
+
234
+ # project out
235
+
236
+ out = self.project_out(codes)
237
+
238
+ # reconstitute image or video dimensions
239
+
240
+ if need_move_channel_last:
241
+ out = unpack_one(out, ps, "b * d")
242
+ out = rearrange(out, "b ... d -> b d ...")
243
+
244
+ indices = maybe(unpack_one)(indices, ps, "b * c")
245
+
246
+ if not self.keep_num_codebooks_dim and self.return_indices:
247
+ indices = maybe(rearrange)(indices, "... 1 -> ...")
248
+
249
+ # return quantized output and indices
250
+
251
+ return out, indices
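With FSQ, each codebook dimension is rounded to a small fixed set of levels, so the implicit codebook size is the product of `levels`. A minimal round-trip sketch (import path assumed from this repo's layout):

import torch
from models.bicodec_tokenizer.modules.fsq.finite_scalar_quantization import FSQ

fsq = FSQ(levels=[8, 5, 5, 5])         # codebook size 8*5*5*5 = 1000, codebook_dim = 4
z = torch.randn(2, 16, 4)              # (batch, seq, dim)
quantized, indices = fsq(z)
print(quantized.shape, indices.shape)  # torch.Size([2, 16, 4]) torch.Size([2, 16])
codes = fsq.indices_to_codes(indices)  # maps indices back to the corresponding codes
print(codes.shape)                     # torch.Size([2, 16, 4])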
models/bicodec_tokenizer/modules/fsq/residual_fsq.py ADDED
@@ -0,0 +1,355 @@
1
+ import random
+ from math import ceil  # required by round_up_multiple below
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from typing import List
7
+ from torch import nn
8
+ from torch.nn import Module
9
+ from torch.amp import autocast
10
+ from einx import get_at
11
+ from einops import rearrange, reduce, pack, unpack
12
+
13
+ from .finite_scalar_quantization import FSQ
14
+
15
+
16
+ def exists(val):
17
+ return val is not None
18
+
19
+
20
+ def first(l):
21
+ return l[0]
22
+
23
+
24
+ def default(val, d):
25
+ return val if exists(val) else d
26
+
27
+
28
+ def round_up_multiple(num, mult):
29
+ return ceil(num / mult) * mult
30
+
31
+
32
+ # distributed helpers
33
+
34
+
35
+ def is_distributed():
36
+ return dist.is_initialized() and dist.get_world_size() > 1
37
+
38
+
39
+ def get_maybe_sync_seed(device, max_size=10_000):
40
+ rand_int = torch.randint(0, max_size, (), device=device)
41
+
42
+ if is_distributed():
43
+ dist.all_reduce(rand_int)
44
+
45
+ return rand_int.item()
46
+
47
+
48
+ class ResidualFSQ(Module):
49
+ """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
50
+
51
+ def __init__(
52
+ self,
53
+ *,
54
+ levels: List[int],
55
+ num_quantizers,
56
+ dim=None,
57
+ is_channel_first=False,
58
+ quantize_dropout=False,
59
+ quantize_dropout_cutoff_index=0,
60
+ quantize_dropout_multiple_of=1,
61
+ **kwargs,
62
+ ):
63
+ super().__init__()
64
+ codebook_dim = len(levels)
65
+ dim = default(dim, codebook_dim)
66
+
67
+ requires_projection = codebook_dim != dim
68
+ self.project_in = (
69
+ nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
70
+ )
71
+ self.project_out = (
72
+ nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
73
+ )
74
+ self.has_projections = requires_projection
75
+
76
+ self.is_channel_first = is_channel_first
77
+ self.num_quantizers = num_quantizers
78
+
79
+ self.levels = levels
80
+ self.layers = nn.ModuleList([])
81
+
82
+ levels_tensor = torch.Tensor(levels)
83
+
84
+ scales = []
85
+
86
+ for ind in range(num_quantizers):
87
+ scales.append((levels_tensor - 1) ** -ind)
88
+
89
+ fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
90
+
91
+ self.layers.append(fsq)
92
+
93
+ assert all([not fsq.has_projections for fsq in self.layers])
94
+
95
+ self.codebook_size = self.layers[0].codebook_size
96
+
97
+ self.register_buffer("scales", torch.stack(scales), persistent=False)
98
+
99
+ self.quantize_dropout = quantize_dropout and num_quantizers > 1
100
+
101
+ assert quantize_dropout_cutoff_index >= 0
102
+
103
+ self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
104
+ self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
105
+
106
+ @property
107
+ def codebooks(self):
108
+ codebooks = [layer.implicit_codebook for layer in self.layers]
109
+ codebooks = torch.stack(codebooks, dim=0)
110
+ return codebooks
111
+
112
+ def get_codes_from_indices(self, indices):
113
+
114
+ batch, quantize_dim = indices.shape[0], indices.shape[-1]
115
+
116
+ # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
117
+
118
+ indices, ps = pack([indices], "b * q")
119
+
120
+ # because of quantize dropout, one can pass in indices that are coarse
121
+ # and the network should be able to reconstruct
122
+
123
+ if quantize_dim < self.num_quantizers:
124
+ assert (
125
+ self.quantize_dropout > 0.0
126
+ ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
127
+ indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
128
+
129
+ # take care of quantizer dropout
130
+
131
+ mask = indices == -1
132
+ indices = indices.masked_fill(
133
+ mask, 0
134
+ ) # have it fetch a dummy code to be masked out later
135
+
136
+ all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
137
+
138
+ # mask out any codes that were dropout-ed
139
+
140
+ all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
141
+
142
+ # scale the codes
143
+
144
+ scales = rearrange(self.scales, "q d -> q 1 1 d")
145
+ all_codes = all_codes * scales
146
+
147
+ # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
148
+
149
+ (all_codes,) = unpack(all_codes, ps, "q b * d")
150
+
151
+ return all_codes
152
+
153
+ def get_output_from_indices(self, indices):
154
+ codes = self.get_codes_from_indices(indices)
155
+ codes_summed = reduce(codes, "q ... -> ...", "sum")
156
+ return self.project_out(codes_summed)
157
+
158
+ def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
159
+ num_quant, quant_dropout_multiple_of, device = (
160
+ self.num_quantizers,
161
+ self.quantize_dropout_multiple_of,
162
+ x.device,
163
+ )
164
+
165
+ # handle channel first
166
+
167
+ if self.is_channel_first:
168
+ x = rearrange(x, "b d ... -> b ... d")
169
+ x, ps = pack([x], "b * d")
170
+
171
+ # maybe project in
172
+
173
+ x = self.project_in(x)
174
+
175
+ quantized_out = 0.0
176
+ residual = x
177
+
178
+ all_indices = []
179
+
180
+ should_quantize_dropout = self.training and self.quantize_dropout
181
+
182
+ # sample a layer index at which to dropout further residual quantization
183
+ # also prepare null indices
184
+
185
+ if should_quantize_dropout:
186
+
187
+ # check if seed is manually passed in
188
+
189
+ if not exists(rand_quantize_dropout_fixed_seed):
190
+ rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
191
+
192
+ rand = random.Random(rand_quantize_dropout_fixed_seed)
193
+
194
+ rand_quantize_dropout_index = rand.randrange(
195
+ self.quantize_dropout_cutoff_index, num_quant
196
+ )
197
+
198
+ if quant_dropout_multiple_of != 1:
199
+ rand_quantize_dropout_index = (
200
+ round_up_multiple(
201
+ rand_quantize_dropout_index + 1, quant_dropout_multiple_of
202
+ )
203
+ - 1
204
+ )
205
+
206
+ null_indices = torch.full(
207
+ x.shape[:2], -1.0, device=device, dtype=torch.long
208
+ )
209
+
210
+ # go through the layers
211
+
212
+ with autocast("cuda", enabled=False):
213
+ for quantizer_index, (layer, scale) in enumerate(
214
+ zip(self.layers, self.scales)
215
+ ):
216
+
217
+ if (
218
+ should_quantize_dropout
219
+ and quantizer_index > rand_quantize_dropout_index
220
+ ):
221
+ all_indices.append(null_indices)
222
+ continue
223
+
224
+ quantized, indices = layer(residual / scale)
225
+
226
+ quantized = quantized * scale
227
+
228
+ residual = residual - quantized.detach()
229
+ quantized_out = quantized_out + quantized
230
+
231
+ all_indices.append(indices)
232
+
233
+ # project out, if needed
234
+
235
+ quantized_out = self.project_out(quantized_out)
236
+
237
+ # stack all indices
238
+
239
+ all_indices = torch.stack(all_indices, dim=-1)
240
+
241
+ # channel first out
242
+
243
+ if self.is_channel_first:
244
+ (quantized_out,) = unpack(quantized_out, ps, "b * d")
245
+ (all_indices,) = unpack(all_indices, ps, "b * d")
246
+
247
+ quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
248
+ all_indices = rearrange(all_indices, "b ... d -> b d ...")
249
+
250
+ # return
251
+
252
+ ret = (quantized_out, all_indices)
253
+
254
+ if not return_all_codes:
255
+ return ret
256
+
257
+ # whether to return all codes from all codebooks across layers
258
+
259
+ all_codes = self.get_codes_from_indices(all_indices)
260
+
261
+ # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
262
+
263
+ return (*ret, all_codes)
264
+
265
+
266
+ # grouped residual fsq
267
+
268
+
269
+ class GroupedResidualFSQ(Module):
270
+ def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.groups = groups
274
+ assert (dim % groups) == 0
275
+ dim_per_group = dim // groups
276
+
277
+ self.accept_image_fmap = accept_image_fmap
278
+
279
+ self.rvqs = nn.ModuleList([])
280
+
281
+ for _ in range(groups):
282
+ self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
283
+
284
+ self.codebook_size = self.rvqs[0].codebook_size
285
+
286
+ @property
287
+ def codebooks(self):
288
+ return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
289
+
290
+ @property
291
+ def split_dim(self):
292
+ return 1 if self.accept_image_fmap else -1
293
+
294
+ def get_codes_from_indices(self, indices):
295
+ codes = tuple(
296
+ rvq.get_codes_from_indices(chunk_indices)
297
+ for rvq, chunk_indices in zip(self.rvqs, indices)
298
+ )
299
+ return torch.stack(codes)
300
+
301
+ def get_output_from_indices(self, indices):
302
+ outputs = tuple(
303
+ rvq.get_output_from_indices(chunk_indices)
304
+ for rvq, chunk_indices in zip(self.rvqs, indices)
305
+ )
306
+ return torch.cat(outputs, dim=self.split_dim)
307
+
308
+ def forward(self, x, return_all_codes=False):
309
+ shape, split_dim, device = x.shape, self.split_dim, x.device
310
+ assert shape[split_dim] == self.dim
311
+
312
+ # split the feature dimension into groups
313
+
314
+ x = x.chunk(self.groups, dim=split_dim)
315
+
316
+ forward_kwargs = dict(
317
+ return_all_codes=return_all_codes,
318
+ rand_quantize_dropout_fixed_seed=(
319
+ get_maybe_sync_seed(device) if self.training else None
320
+ ),
321
+ )
322
+
323
+ # invoke residual vq on each group
324
+
325
+ out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
326
+ out = tuple(zip(*out))
327
+
328
+ # otherwise, get all the zipped outputs and combine them
329
+
330
+ quantized, all_indices, *maybe_all_codes = out
331
+
332
+ quantized = torch.cat(quantized, dim=split_dim)
333
+ all_indices = torch.stack(all_indices)
334
+
335
+ ret = (quantized, all_indices, *maybe_all_codes)
336
+ return ret
337
+
338
+
339
+ if __name__ == "__main__":
340
+ model = ResidualFSQ(
341
+ levels=[4, 4, 4, 4, 4, 4],
342
+ num_quantizers=1,
343
+ dim=30,
344
+ is_channel_first=True,
345
+ quantize_dropout=False,
346
+ )
347
+ x = torch.randn(2, 30, 10)
348
+ quantize, embed_ind = model(x)
349
+
350
+ emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
351
+
352
+ print(quantize == emb_from_ind.transpose(1, 2))
353
+
354
+ print("quantize shape", quantize.shape)
355
+ print("embed_ind", embed_ind)
models/bicodec_tokenizer/modules/speaker/__init__.py ADDED
File without changes
models/bicodec_tokenizer/modules/speaker/ecapa_tdnn.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) 2021 Zhengyang Chen (chenzhengyang117@gmail.com)
2
+ # 2022 Hongji Wang (jijijiang77@gmail.com)
3
+ # 2023 Bing Han (hanbing97@sjtu.edu.cn)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ This implementation is adapted from github repo:
18
+ https://github.com/lawlict/ECAPA-TDNN.
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ from . import pooling_layers
26
+
27
+
28
+ class Res2Conv1dReluBn(nn.Module):
29
+ """
30
+ in_channels == out_channels == channels
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ channels,
36
+ kernel_size=1,
37
+ stride=1,
38
+ padding=0,
39
+ dilation=1,
40
+ bias=True,
41
+ scale=4,
42
+ ):
43
+ super().__init__()
44
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
45
+ self.scale = scale
46
+ self.width = channels // scale
47
+ self.nums = scale if scale == 1 else scale - 1
48
+
49
+ self.convs = []
50
+ self.bns = []
51
+ for i in range(self.nums):
52
+ self.convs.append(
53
+ nn.Conv1d(
54
+ self.width,
55
+ self.width,
56
+ kernel_size,
57
+ stride,
58
+ padding,
59
+ dilation,
60
+ bias=bias,
61
+ )
62
+ )
63
+ self.bns.append(nn.BatchNorm1d(self.width))
64
+ self.convs = nn.ModuleList(self.convs)
65
+ self.bns = nn.ModuleList(self.bns)
66
+
67
+ def forward(self, x):
68
+ out = []
69
+ spx = torch.split(x, self.width, 1)
70
+ sp = spx[0]
71
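+ # Res2Net-style hierarchy: from the second split on, each split is summed with the previous branch's output before its own conv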
+ for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
72
+ # Order: conv -> relu -> bn
73
+ if i >= 1:
74
+ sp = sp + spx[i]
75
+ sp = conv(sp)
76
+ sp = bn(F.relu(sp))
77
+ out.append(sp)
78
+ if self.scale != 1:
79
+ out.append(spx[self.nums])
80
+ out = torch.cat(out, dim=1)
81
+
82
+ return out
83
+
84
+
85
+ """ Conv1d + BatchNorm1d + ReLU
86
+ """
87
+
88
+
89
+ class Conv1dReluBn(nn.Module):
90
+
91
+ def __init__(
92
+ self,
93
+ in_channels,
94
+ out_channels,
95
+ kernel_size=1,
96
+ stride=1,
97
+ padding=0,
98
+ dilation=1,
99
+ bias=True,
100
+ ):
101
+ super().__init__()
102
+ self.conv = nn.Conv1d(
103
+ in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
104
+ )
105
+ self.bn = nn.BatchNorm1d(out_channels)
106
+
107
+ def forward(self, x):
108
+ return self.bn(F.relu(self.conv(x)))
109
+
110
+
111
+ """ The SE connection of 1D case.
112
+ """
113
+
114
+
115
+ class SE_Connect(nn.Module):
116
+
117
+ def __init__(self, channels, se_bottleneck_dim=128):
118
+ super().__init__()
119
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
120
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
121
+
122
+ def forward(self, x):
123
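+ # Squeeze-and-Excitation: mean-pool over time, bottleneck MLP, then sigmoid gates rescale each channel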
+ out = x.mean(dim=2)
124
+ out = F.relu(self.linear1(out))
125
+ out = torch.sigmoid(self.linear2(out))
126
+ out = x * out.unsqueeze(2)
127
+
128
+ return out
129
+
130
+
131
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
132
+ """
133
+
134
+
135
+ class SE_Res2Block(nn.Module):
136
+
137
+ def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
138
+ super().__init__()
139
+ self.se_res2block = nn.Sequential(
140
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
141
+ Res2Conv1dReluBn(
142
+ channels, kernel_size, stride, padding, dilation, scale=scale
143
+ ),
144
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
145
+ SE_Connect(channels),
146
+ )
147
+
148
+ def forward(self, x):
149
+ return x + self.se_res2block(x)
150
+
151
+
152
+ class ECAPA_TDNN(nn.Module):
153
+
154
+ def __init__(
155
+ self,
156
+ channels=512,
157
+ feat_dim=80,
158
+ embed_dim=192,
159
+ pooling_func="ASTP",
160
+ global_context_att=False,
161
+ emb_bn=False,
162
+ ):
163
+ super().__init__()
164
+
165
+ self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
166
+ self.layer2 = SE_Res2Block(
167
+ channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
168
+ )
169
+ self.layer3 = SE_Res2Block(
170
+ channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
171
+ )
172
+ self.layer4 = SE_Res2Block(
173
+ channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
174
+ )
175
+
176
+ cat_channels = channels * 3
177
+ out_channels = 512 * 3
178
+ self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
179
+ self.pool = getattr(pooling_layers, pooling_func)(
180
+ in_dim=out_channels, global_context_att=global_context_att
181
+ )
182
+ self.pool_out_dim = self.pool.get_out_dim()
183
+ self.bn = nn.BatchNorm1d(self.pool_out_dim)
184
+ self.linear = nn.Linear(self.pool_out_dim, embed_dim)
185
+ self.emb_bn = emb_bn
186
+ if emb_bn: # better in SSL for SV
187
+ self.bn2 = nn.BatchNorm1d(embed_dim)
188
+ else:
189
+ self.bn2 = nn.Identity()
190
+
191
+ def forward(self, x, return_latent=False):
192
+ x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
193
+
194
+ out1 = self.layer1(x)
195
+ out2 = self.layer2(out1)
196
+ out3 = self.layer3(out2)
197
+ out4 = self.layer4(out3)
198
+
199
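+ # Multi-layer feature aggregation: concatenate the outputs of the three SE-Res2Blocks along the channel dimension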
+ out = torch.cat([out2, out3, out4], dim=1)
200
+ latent = F.relu(self.conv(out))
201
+ out = self.bn(self.pool(latent))
202
+ out = self.linear(out)
203
+ if self.emb_bn:
204
+ out = self.bn2(out)
205
+
206
+ if return_latent:
207
+ return out, latent
208
+ return out
209
+
210
+
211
+ def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
212
+ return ECAPA_TDNN(
213
+ channels=1024,
214
+ feat_dim=feat_dim,
215
+ embed_dim=embed_dim,
216
+ pooling_func=pooling_func,
217
+ emb_bn=emb_bn,
218
+ )
219
+
220
+
221
+ def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
222
+ return ECAPA_TDNN(
223
+ channels=1024,
224
+ feat_dim=feat_dim,
225
+ embed_dim=embed_dim,
226
+ pooling_func=pooling_func,
227
+ global_context_att=True,
228
+ emb_bn=emb_bn,
229
+ )
230
+
231
+
232
+ def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
233
+ return ECAPA_TDNN(
234
+ channels=512,
235
+ feat_dim=feat_dim,
236
+ embed_dim=embed_dim,
237
+ pooling_func=pooling_func,
238
+ emb_bn=emb_bn,
239
+ )
240
+
241
+
242
+ def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
243
+ return ECAPA_TDNN(
244
+ channels=512,
245
+ feat_dim=feat_dim,
246
+ embed_dim=embed_dim,
247
+ pooling_func=pooling_func,
248
+ global_context_att=True,
249
+ emb_bn=emb_bn,
250
+ )
251
+
252
+
253
+ if __name__ == "__main__":
254
+ x = torch.zeros(1, 200, 100)
255
+ model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
256
+ model.eval()
257
+ out, latent = model(x, True)
258
+ print(out.shape)
259
+ print(latent.shape)
260
+
261
+ num_params = sum(param.numel() for param in model.parameters())
262
+ print("{} M".format(num_params / 1e6))
263
+
264
+ # from thop import profile
265
+ # x_np = torch.randn(1, 200, 80)
266
+ # flops, params = profile(model, inputs=(x_np, ))
267
+ # print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
models/bicodec_tokenizer/modules/speaker/perceiver_encoder.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
17
+
18
+ from collections import namedtuple
19
+ from functools import wraps
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange, repeat
24
+ from einops.layers.torch import Rearrange
25
+ from packaging import version
26
+ from torch import einsum, nn
27
+
28
+
29
+ def exists(val):
30
+ return val is not None
31
+
32
+
33
+ def once(fn):
34
+ called = False
35
+
36
+ @wraps(fn)
37
+ def inner(x):
38
+ nonlocal called
39
+ if called:
40
+ return
41
+ called = True
42
+ return fn(x)
43
+
44
+ return inner
45
+
46
+
47
+ print_once = once(print)
48
+
49
+ # main class
50
+
51
+
52
+ class Attend(nn.Module):
53
+ def __init__(self, dropout=0.0, causal=False, use_flash=False):
54
+ super().__init__()
55
+ self.dropout = dropout
56
+ self.attn_dropout = nn.Dropout(dropout)
57
+
58
+ self.causal = causal
59
+ self.register_buffer("mask", None, persistent=False)
60
+
61
+ self.use_flash = use_flash
62
+ assert not (
63
+ use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
64
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
65
+
66
+ # determine efficient attention configs for cuda and cpu
67
+ self.config = namedtuple(
68
+ "EfficientAttentionConfig",
69
+ ["enable_flash", "enable_math", "enable_mem_efficient"],
70
+ )
71
+ self.cpu_config = self.config(True, True, True)
72
+ self.cuda_config = None
73
+
74
+ if not torch.cuda.is_available() or not use_flash:
75
+ return
76
+
77
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
78
+
79
+ if device_properties.major == 8 and device_properties.minor == 0:
80
+ print_once(
81
+ "A100 GPU detected, using flash attention if input tensor is on cuda"
82
+ )
83
+ self.cuda_config = self.config(True, False, False)
84
+ else:
85
+ print_once(
86
+ "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
87
+ )
88
+ self.cuda_config = self.config(False, True, True)
89
+
90
+ def get_mask(self, n, device):
91
+ if exists(self.mask) and self.mask.shape[-1] >= n:
92
+ return self.mask[:n, :n]
93
+
94
+ mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
95
+ self.register_buffer("mask", mask, persistent=False)
96
+ return mask
97
+
98
+ def flash_attn(self, q, k, v, mask=None):
99
+ _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
100
+
101
+ # Recommended for multi-query single-key-value attention by Tri Dao
102
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
103
+
104
+ if k.ndim == 3:
105
+ k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
106
+
107
+ if v.ndim == 3:
108
+ v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
109
+
110
+ # Check if mask exists and expand to compatible shape
111
+ # The mask is B L, so it would have to be expanded to B H N L
112
+
113
+ if exists(mask):
114
+ mask = rearrange(mask, "b j -> b 1 1 j")
115
+ mask = mask.expand(-1, heads, q_len, -1)
116
+
117
+ # Check if there is a compatible device for flash attention
118
+
119
+ config = self.cuda_config if is_cuda else self.cpu_config
120
+
121
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
122
+
123
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
124
+ out = F.scaled_dot_product_attention(
125
+ q,
126
+ k,
127
+ v,
128
+ attn_mask=mask,
129
+ dropout_p=self.dropout if self.training else 0.0,
130
+ is_causal=self.causal,
131
+ )
132
+
133
+ return out
134
+
135
+ def forward(self, q, k, v, mask=None):
136
+ """
137
+ einstein notation
138
+ b - batch
139
+ h - heads
140
+ n, i, j - sequence length (base sequence length, source, target)
141
+ d - feature dimension
142
+ """
143
+
144
+ n, device = q.shape[-2], q.device
145
+
146
+ scale = q.shape[-1] ** -0.5
147
+
148
+ if self.use_flash:
149
+ return self.flash_attn(q, k, v, mask=mask)
150
+
151
+ kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
152
+
153
+ # similarity
154
+
155
+ sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
156
+
157
+ # key padding mask
158
+
159
+ if exists(mask):
160
+ mask = rearrange(mask, "b j -> b 1 1 j")
161
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
162
+
163
+ # causal mask
164
+
165
+ if self.causal:
166
+ causal_mask = self.get_mask(n, device)
167
+ sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
168
+
169
+ # attention
170
+
171
+ attn = sim.softmax(dim=-1)
172
+ attn = self.attn_dropout(attn)
173
+
174
+ # aggregate values
175
+
176
+ out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
177
+
178
+ return out
179
+
180
+
181
+ def Sequential(*mods):
182
+ return nn.Sequential(*filter(exists, mods))
183
+
184
+
185
+ def exists(x):
186
+ return x is not None
187
+
188
+
189
+ def default(val, d):
190
+ if exists(val):
191
+ return val
192
+ return d() if callable(d) else d
193
+
194
+
195
+ class RMSNorm(nn.Module):
196
+ def __init__(self, dim, scale=True, dim_cond=None):
197
+ super().__init__()
198
+ self.cond = exists(dim_cond)
199
+ self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
200
+
201
+ self.scale = dim**0.5
202
+ self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
203
+
204
+ def forward(self, x, cond=None):
205
+ gamma = default(self.gamma, 1)
206
+ out = F.normalize(x, dim=-1) * self.scale * gamma
207
+
208
+ if not self.cond:
209
+ return out
210
+
211
+ assert exists(cond)
212
+ gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
213
+ gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
214
+ return out * gamma + beta
215
+
216
+
217
+ class CausalConv1d(nn.Conv1d):
218
+ def __init__(self, *args, **kwargs):
219
+ super().__init__(*args, **kwargs)
220
+ (kernel_size,) = self.kernel_size
221
+ (dilation,) = self.dilation
222
+ (stride,) = self.stride
223
+
224
+ assert stride == 1
225
+ self.causal_padding = dilation * (kernel_size - 1)
226
+
227
+ def forward(self, x):
228
+ causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
229
+ return super().forward(causal_padded_x)
230
+
231
+
232
+ class GEGLU(nn.Module):
233
+ def forward(self, x):
234
+ x, gate = x.chunk(2, dim=-1)
235
+ return F.gelu(gate) * x
236
+
237
+
238
+ def FeedForward(dim, mult=4, causal_conv=False):
239
+ dim_inner = int(dim * mult * 2 / 3)
240
+
241
+ conv = None
242
+ if causal_conv:
243
+ conv = nn.Sequential(
244
+ Rearrange("b n d -> b d n"),
245
+ CausalConv1d(dim_inner, dim_inner, 3),
246
+ Rearrange("b d n -> b n d"),
247
+ )
248
+
249
+ return Sequential(
250
+ nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
251
+ )
252
+
253
+
254
+ class Attention(nn.Module):
255
+ def __init__(
256
+ self,
257
+ dim,
258
+ *,
259
+ dim_context=None,
260
+ causal=False,
261
+ dim_head=64,
262
+ heads=8,
263
+ dropout=0.0,
264
+ use_flash=False,
265
+ cross_attn_include_queries=False,
266
+ ):
267
+ super().__init__()
268
+ self.scale = dim_head**-0.5
269
+ self.heads = heads
270
+ self.cross_attn_include_queries = cross_attn_include_queries
271
+
272
+ dim_inner = dim_head * heads
273
+ dim_context = default(dim_context, dim)
274
+
275
+ self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
276
+ self.to_q = nn.Linear(dim, dim_inner, bias=False)
277
+ self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
278
+ self.to_out = nn.Linear(dim_inner, dim, bias=False)
279
+
280
+ def forward(self, x, context=None, mask=None):
281
+ h, has_context = self.heads, exists(context)
282
+
283
+ context = default(context, x)
284
+
285
+ if has_context and self.cross_attn_include_queries:
286
+ context = torch.cat((x, context), dim=-2)
287
+
288
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
289
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
290
+
291
+ out = self.attend(q, k, v, mask=mask)
292
+
293
+ out = rearrange(out, "b h n d -> b n (h d)")
294
+ return self.to_out(out)
295
+
296
+
297
+ class PerceiverResampler(nn.Module):
298
+ def __init__(
299
+ self,
300
+ *,
301
+ dim,
302
+ depth=2,
303
+ dim_context=None,
304
+ num_latents=32,
305
+ dim_head=64,
306
+ heads=8,
307
+ ff_mult=4,
308
+ use_flash_attn=False,
309
+ ):
310
+ super().__init__()
311
+ dim_context = default(dim_context, dim)
312
+
313
+ self.proj_context = (
314
+ nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
315
+ )
316
+
317
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
318
+ nn.init.normal_(self.latents, std=0.02)
319
+
320
+ self.layers = nn.ModuleList([])
321
+ for _ in range(depth):
322
+ self.layers.append(
323
+ nn.ModuleList(
324
+ [
325
+ Attention(
326
+ dim=dim,
327
+ dim_head=dim_head,
328
+ heads=heads,
329
+ use_flash=use_flash_attn,
330
+ cross_attn_include_queries=True,
331
+ ),
332
+ FeedForward(dim=dim, mult=ff_mult),
333
+ ]
334
+ )
335
+ )
336
+
337
+ self.norm = RMSNorm(dim)
338
+
339
+ def forward(self, x, mask=None):
340
+ batch = x.shape[0]
341
+
342
+ x = self.proj_context(x)
343
+
344
+ latents = repeat(self.latents, "n d -> b n d", b=batch)
345
+
346
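+ # Each layer: the learned latent queries cross-attend to the projected input (latents included as context), then a feed-forward refines them; both use residual connections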
+ for attn, ff in self.layers:
347
+ latents = attn(latents, x, mask=mask) + latents
348
+ latents = ff(latents) + latents
349
+
350
+ return self.norm(latents)
351
+
352
+
353
+ if __name__ == "__main__":
354
+ model = PerceiverResampler(dim=256, dim_context=80)
355
+ x = torch.randn(8, 200, 80)
356
+ out = model(x)
357
+ print(out.shape)  # [8, 32, 256]
358
+
359
+ num_params = sum(param.numel() for param in model.parameters())
360
+ print("{} M".format(num_params / 1e6))
models/bicodec_tokenizer/modules/speaker/pooling_layers.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Pooling functions to aggregate frame-level deep features
16
+ into segment-level speaker embeddings
17
+
18
+ High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
19
+ even though we remove the mean statistic, on Voxceleb.
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+
27
+ class TAP(nn.Module):
28
+ """
29
+ Temporal average pooling, only first-order mean is considered
30
+ """
31
+
32
+ def __init__(self, in_dim=0, **kwargs):
33
+ super(TAP, self).__init__()
34
+ self.in_dim = in_dim
35
+
36
+ def forward(self, x):
37
+ pooling_mean = x.mean(dim=-1)
38
+ # To be compatible with 2D input
39
+ pooling_mean = pooling_mean.flatten(start_dim=1)
40
+ return pooling_mean
41
+
42
+ def get_out_dim(self):
43
+ self.out_dim = self.in_dim
44
+ return self.out_dim
45
+
46
+
47
+ class TSDP(nn.Module):
48
+ """
49
+ Temporal standard deviation pooling, only second-order std is considered
50
+ """
51
+
52
+ def __init__(self, in_dim=0, **kwargs):
53
+ super(TSDP, self).__init__()
54
+ self.in_dim = in_dim
55
+
56
+ def forward(self, x):
57
+ # The last dimension is the temporal axis
58
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
59
+ pooling_std = pooling_std.flatten(start_dim=1)
60
+ return pooling_std
61
+
62
+ def get_out_dim(self):
63
+ self.out_dim = self.in_dim
64
+ return self.out_dim
65
+
66
+
67
+ class TSTP(nn.Module):
68
+ """
69
+ Temporal statistics pooling, concatenate mean and std, which is used in
70
+ x-vector
71
+ Comment: simple concatenation can not make full use of both statistics
72
+ """
73
+
74
+ def __init__(self, in_dim=0, **kwargs):
75
+ super(TSTP, self).__init__()
76
+ self.in_dim = in_dim
77
+
78
+ def forward(self, x):
79
+ # The last dimension is the temporal axis
80
+ pooling_mean = x.mean(dim=-1)
81
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
82
+ pooling_mean = pooling_mean.flatten(start_dim=1)
83
+ pooling_std = pooling_std.flatten(start_dim=1)
84
+ stats = torch.cat((pooling_mean, pooling_std), 1)
85
+ return stats
86
+
87
+ def get_out_dim(self):
88
+ self.out_dim = self.in_dim * 2
89
+ return self.out_dim
90
+
91
+
92
+ class ASTP(nn.Module):
93
+ """ Attentive statistics pooling: Channel- and context-dependent
94
+ statistics pooling, first used in ECAPA_TDNN.
95
+ """
96
+
97
+ def __init__(self,
98
+ in_dim,
99
+ bottleneck_dim=128,
100
+ global_context_att=False,
101
+ **kwargs):
102
+ super(ASTP, self).__init__()
103
+ self.in_dim = in_dim
104
+ self.global_context_att = global_context_att
105
+
106
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
107
+ # need to transpose inputs.
108
+ if global_context_att:
109
+ self.linear1 = nn.Conv1d(
110
+ in_dim * 3, bottleneck_dim,
111
+ kernel_size=1) # equals W and b in the paper
112
+ else:
113
+ self.linear1 = nn.Conv1d(
114
+ in_dim, bottleneck_dim,
115
+ kernel_size=1) # equals W and b in the paper
116
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
117
+ kernel_size=1) # equals V and k in the paper
118
+
119
+ def forward(self, x):
120
+ """
121
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
122
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
123
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
124
+ """
125
+ if len(x.shape) == 4:
126
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
127
+ assert len(x.shape) == 3
128
+
129
+ if self.global_context_att:
130
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
131
+ context_std = torch.sqrt(
132
+ torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
133
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
134
+ else:
135
+ x_in = x
136
+
137
+ # DON'T use ReLU here! ReLU may be hard to converge.
138
+ alpha = torch.tanh(
139
+ self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
140
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
141
+ mean = torch.sum(alpha * x, dim=2)
142
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
143
+ std = torch.sqrt(var.clamp(min=1e-7))
144
+ return torch.cat([mean, std], dim=1)
145
+
146
+ def get_out_dim(self):
147
+ self.out_dim = 2 * self.in_dim
148
+ return self.out_dim
149
+
150
+
151
+ class MHASTP(torch.nn.Module):
152
+ """ Multi head attentive statistics pooling
153
+ Reference:
154
+ Self Multi-Head Attention for Speaker Recognition
155
+ https://arxiv.org/pdf/1906.09890.pdf
156
+ """
157
+
158
+ def __init__(self,
159
+ in_dim,
160
+ layer_num=2,
161
+ head_num=2,
162
+ d_s=1,
163
+ bottleneck_dim=64,
164
+ **kwargs):
165
+ super(MHASTP, self).__init__()
166
+ assert (in_dim % head_num
167
+ ) == 0 # make sure that head num can be divided by input_dim
168
+ self.in_dim = in_dim
169
+ self.head_num = head_num
170
+ d_model = int(in_dim / head_num)
171
+ channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
172
+ if d_s > 1:
173
+ d_s = d_model
174
+ else:
175
+ d_s = 1
176
+ self.d_s = d_s
177
+ channel_dims[0], channel_dims[-1] = d_model, d_s
178
+ heads_att_trans = []
179
+ for i in range(self.head_num):
180
+ att_trans = nn.Sequential()
181
+ for i in range(layer_num - 1):
182
+ att_trans.add_module(
183
+ 'att_' + str(i),
184
+ nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
185
+ att_trans.add_module('tanh' + str(i), nn.Tanh())
186
+ att_trans.add_module(
187
+ 'att_' + str(layer_num - 1),
188
+ nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
189
+ 1, 1))
190
+ heads_att_trans.append(att_trans)
191
+ self.heads_att_trans = nn.ModuleList(heads_att_trans)
192
+
193
+ def forward(self, input):
194
+ """
195
+ input: a 3-dimensional tensor in xvector architecture
196
+ or a 4-dimensional tensor in resnet architecture
197
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
198
+ """
199
+ if len(input.shape) == 4: # B x F x T
200
+ input = input.reshape(input.shape[0],
201
+ input.shape[1] * input.shape[2],
202
+ input.shape[3])
203
+ assert len(input.shape) == 3
204
+ bs, f_dim, t_dim = input.shape
205
+ chunks = torch.chunk(input, self.head_num, 1)
206
+ # split
207
+ chunks_out = []
208
+ # for i in range(self.head_num):
209
+ # att_score = self.heads_att_trans[i](chunks[i])
210
+ for i, layer in enumerate(self.heads_att_trans):
211
+ att_score = layer(chunks[i])
212
+ alpha = F.softmax(att_score, dim=-1)
213
+ mean = torch.sum(alpha * chunks[i], dim=2)
214
+ var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
215
+ std = torch.sqrt(var.clamp(min=1e-7))
216
+ chunks_out.append(torch.cat((mean, std), dim=1))
217
+ out = torch.cat(chunks_out, dim=1)
218
+ return out
219
+
220
+ def get_out_dim(self):
221
+ self.out_dim = 2 * self.in_dim
222
+ return self.out_dim
223
+
224
+
225
+ class MQMHASTP(torch.nn.Module):
226
+ """ An attentive pooling
227
+ Reference:
228
+ multi query multi head attentive statistics pooling
229
+ https://arxiv.org/pdf/2110.05042.pdf
230
+ Args:
231
+ in_dim: the feature dimension of input
232
+ layer_num: the number of layer in the pooling layer
233
+ query_num: the number of querys
234
+ head_num: the number of heads
235
+ bottleneck_dim: the bottleneck dimension
236
+
237
+ SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
238
+ https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
239
+ MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
240
+ https://arxiv.org/pdf/1906.09890.pdf
241
+ AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
242
+ https://arxiv.org/pdf/1803.10963.pdf
243
+ VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
244
+ http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
245
+ """
246
+
247
+ def __init__(self,
248
+ in_dim,
249
+ layer_num=2,
250
+ query_num=2,
251
+ head_num=8,
252
+ d_s=2,
253
+ bottleneck_dim=64,
254
+ **kwargs):
255
+ super(MQMHASTP, self).__init__()
256
+ self.n_query = nn.ModuleList([
257
+ MHASTP(in_dim,
258
+ layer_num=layer_num,
259
+ head_num=head_num,
260
+ d_s=d_s,
261
+ bottleneck_dim=bottleneck_dim) for i in range(query_num)
262
+ ])
263
+ self.query_num = query_num
264
+ self.in_dim = in_dim
265
+
266
+ def forward(self, input):
267
+ """
268
+ input: a 3-dimensional tensor in xvector architecture
269
+ or a 4-dimensional tensor in resnet architecture
270
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
271
+ """
272
+ if len(input.shape) == 4: # B x F x T
273
+ input = input.reshape(input.shape[0],
274
+ input.shape[1] * input.shape[2],
275
+ input.shape[3])
276
+ assert len(input.shape) == 3
277
+ res = []
278
+ for i, layer in enumerate(self.n_query):
279
+ res.append(layer(input))
280
+ out = torch.cat(res, dim=-1)
281
+ return out
282
+
283
+ def get_out_dim(self):
284
+ self.out_dim = self.in_dim * 2 * self.query_num
285
+ return self.out_dim
286
+
287
+
288
+ if __name__ == '__main__':
289
+ data = torch.randn(16, 512, 10, 35)
290
+ # model = StatisticsPooling()
291
+ model = MQMHASTP(512 * 10)
292
+ model = MHASTP(512 * 10)
293
+ model = MQMHASTP(512 * 10, context=False)
294
+ print(model)
295
+
296
+ out = model(data)
297
+ print(out.shape)
298
+ print(model.get_out_dim())
models/bicodec_tokenizer/modules/speaker/speaker_encoder.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from typing import List, Tuple
20
+ from ..fsq.residual_fsq import ResidualFSQ
21
+ from .ecapa_tdnn import ECAPA_TDNN_GLOB_c512
22
+ from .perceiver_encoder import PerceiverResampler
23
+
24
+ """
25
+ x-vector + d-vector
26
+ """
27
+
28
+
29
+ class SpeakerEncoder(nn.Module):
30
+ """
31
+
32
+ Args:
33
+ input_dim (int): acoustic feature dimension
34
+ out_dim (int): output dimension of x-vector and d-vector
35
+ latent_dim (int): latent dimension before quantization
36
+ token_num (int): sequence length of speaker tokens
37
+ fsq_levels (List[int]): number of levels for each quantizer
38
+ fsq_num_quantizers (int): number of quantizers
39
+
40
+ Return:
41
+ speaker_embs: (B, T2, out_dim)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ input_dim: int = 100,
47
+ out_dim: int = 512,
48
+ latent_dim: int = 128,
49
+ token_num: int = 32,
50
+ fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
51
+ fsq_num_quantizers: int = 1,
52
+ ):
53
+ super(SpeakerEncoder, self).__init__()
54
+
55
+ self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
56
+ feat_dim=input_dim, embed_dim=out_dim
57
+ )
58
+ self.perceiver_sampler = PerceiverResampler(
59
+ dim=latent_dim, dim_context=512 * 3, num_latents=token_num
60
+ )
61
+ self.quantizer = ResidualFSQ(
62
+ levels=fsq_levels,
63
+ num_quantizers=fsq_num_quantizers,
64
+ dim=latent_dim,
65
+ is_channel_first=True,
66
+ quantize_dropout=False,
67
+ )
68
+
69
+ self.project = nn.Linear(latent_dim * token_num, out_dim)
70
+
71
+ def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
72
+ zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
73
+ return zq.transpose(1, 2)
74
+
75
+ def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
76
+ mels = mels.transpose(1, 2)
77
+ x = self.perceiver_sampler(mels).transpose(1, 2)
78
+ zq, indices = self.quantizer(x)
79
+ return indices
80
+
81
+ def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
82
+ """
83
+ Args:
84
+ mels: (B, D_mel, T1)
85
+
86
+ Return:
87
+ x_vector: (B, out_dim)
88
+ d_vector: (B, out_dim)
89
+ """
90
+ # mels = mels.transpose(1,2)
91
+
92
+ x_vector, features = self.speaker_encoder(mels, True)
93
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
94
+ zq, indices = self.quantizer(x)  # zq: (B, latent_dim, T2)
95
+ x = zq.reshape(zq.shape[0], -1)
96
+ d_vector = self.project(x)
97
+
98
+ return x_vector, d_vector
99
+
100
+ def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
101
+ """tokenize the input mel spectrogram"""
102
+ _, features = self.speaker_encoder(mels, True)
103
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
104
+ zq, indices = self.quantizer(x)
105
+ return indices
106
+
107
+ def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
108
+ """detokenize the input indices to d-vector"""
109
+ zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
110
+ x = zq.reshape(zq.shape[0], -1)
111
+ d_vector = self.project(x)
112
+ return d_vector
113
+
114
+ if __name__ == "__main__":
115
+ model = SpeakerEncoder(
116
+ input_dim=100,
117
+ latent_dim=128,
118
+ token_num=32,
119
+ fsq_levels=[4, 4, 4, 4, 4, 4],
120
+ fsq_num_quantizers=1,
121
+ )
122
+ mel = torch.randn(8, 200, 100)
123
+ x_vector, d_vector = model(mel)
124
+ print("x-vector shape", x_vector.shape)
125
+ print("d-vector shape", d_vector.shape)
126
+
127
+ indices = model.tokenize(mel)
128
+ print("indices shape", indices.shape)
129
+ d_vector_post = model.detokenize(indices)
130
+ print("d-vector shape", d_vector_post.shape)
131
+ if torch.allclose(d_vector_post, d_vector):
132
+ print("d-vector post and d-vector are the same")
133
+ else:
134
+ print("d-vector post and d-vector are different")
135
+ num_params = sum(param.numel() for param in model.parameters())
136
+ print("{} M".format(num_params / 1e6))
models/bicodec_tokenizer/modules/vq/factorized_vector_quantize.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
17
+
18
+
19
+ from typing import Any, Dict
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from einops import rearrange
25
+ from torch.nn.utils import weight_norm
26
+
27
+
28
+ def WNConv1d(*args, **kwargs):
29
+ return weight_norm(nn.Conv1d(*args, **kwargs))
30
+
31
+
32
+ def ema_inplace(moving_avg, new, decay):
33
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
34
+
35
+
36
+ class FactorizedVectorQuantize(nn.Module):
37
+ def __init__(
38
+ self,
39
+ input_dim: int,
40
+ codebook_size: int,
41
+ codebook_dim: int,
42
+ commitment: float,
43
+ codebook_loss_weight: float = 1.0,
44
+ decay: float = 0.99,
45
+ threshold_ema_dead_code: float = 2,
46
+ momentum: float = 0.99,
47
+ **kwargs,
48
+ ):
49
+ super().__init__()
50
+ self.input_dim = input_dim
51
+ self.codebook_size = codebook_size
52
+ self.codebook_dim = codebook_dim
53
+ self.commitment = commitment
54
+ self.codebook_loss_weight = codebook_loss_weight
55
+ self.decay = decay
56
+ self.threshold_ema_dead_code = threshold_ema_dead_code
57
+ self.momentum = momentum
58
+
59
+ if input_dim != self.codebook_dim:
60
+ self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
61
+ self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
62
+
63
+ else:
64
+ self.in_project = nn.Identity()
65
+ self.out_project = nn.Identity()
66
+
67
+ self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
68
+ self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
69
+
70
+ def forward(self, z: torch.Tensor) -> Dict[str, Any]:
71
+ """Quantize the input tensor using a fixed codebook and return
72
+ the corresponding codebook vectors
73
+
74
+ Parameters
75
+ ----------
76
+ z : Tensor[B x D x T]
77
+
78
+ Returns
79
+ -------
80
+ Tensor[B x D x T]
81
+ Quantized continuous representation of input
82
+ Tensor[1]
83
+ Commitment loss to train encoder to predict vectors closer to codebook
84
+ entries
85
+ Tensor[1]
86
+ Codebook loss to update the codebook
87
+ Tensor[B x T]
88
+ Codebook indices (quantized discrete representation of input)
89
+ Tensor[B x D x T]
90
+ Projected latents (continuous representation of input before quantization)
91
+ """
92
+ # transpose since we use linear
93
+
94
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
95
+ z_e = self.in_project(z)
96
+ z_q, indices, dists = self.decode_latents(z_e)
97
+
98
+ # track codebook usage statistics
99
+ embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
100
+ avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
101
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
102
+
103
+ active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
104
+ if self.training:
105
+ # We do the expiry of code at that point as buffers are in sync
106
+ # and all the workers will take the same decision.
107
+ ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
108
+ active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
109
+
110
+ if self.training:
111
+ commit_loss = (
112
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
113
+ * self.commitment
114
+ )
115
+
116
+ codebook_loss = (
117
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
118
+ * self.codebook_loss_weight
119
+ )
120
+
121
+ else:
122
+ commit_loss = torch.zeros(0, device=z.device)
123
+ codebook_loss = torch.zeros(0, device=z.device)
124
+
125
+ z_q = (
126
+ z_e + (z_q - z_e).detach()
127
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
128
+
129
+ z_q = self.out_project(z_q)
130
+
131
+ vq_loss = (commit_loss + codebook_loss).mean()
132
+
133
+ return {
134
+ "z_q": z_q,
135
+ "indices": indices,
136
+ "dists": dists,
137
+ "vq_loss": vq_loss,
138
+ "perplexity": perplexity,
139
+ "active_num": active_num.float(),
140
+ }
141
+
142
+ def vq2emb(self, vq, out_proj=True):
143
+ emb = self.embed_code(vq)
144
+ if out_proj:
145
+ emb = self.out_project(emb)
146
+ return emb
147
+
148
+ def tokenize(self, z: torch.Tensor) -> torch.Tensor:
149
+ """tokenize the input tensor"""
150
+ z_e = self.in_project(z)
151
+ _, indices, _ = self.decode_latents(z_e)
152
+ return indices
153
+
154
+ def detokenize(self, indices):
155
+ """detokenize the input indices"""
156
+ z_q = self.decode_code(indices)
157
+ z_q = self.out_project(z_q)
158
+ return z_q
159
+
160
+ def get_emb(self):
161
+ return self.codebook.weight
162
+
163
+ def embed_code(self, embed_id):
164
+ return F.embedding(embed_id, self.codebook.weight)
165
+
166
+ def decode_code(self, embed_id):
167
+ return self.embed_code(embed_id).transpose(1, 2)
168
+
169
+ def decode_latents(self, latents):
170
+ encodings = rearrange(latents, "b d t -> (b t) d")
171
+ codebook = self.codebook.weight
172
+
173
+ # L2 normalize encodings and codebook
174
+ encodings = F.normalize(encodings)
175
+ codebook = F.normalize(codebook)
176
+
177
+ # Compute euclidean distance between encodings and codebook,
178
+ # with L2 normalization, the distance is equal to cosine distance
179
+ dist = (
180
+ encodings.pow(2).sum(1, keepdim=True)
181
+ - 2 * encodings @ codebook.t()
182
+ + codebook.pow(2).sum(1, keepdim=True).t()
183
+ )
184
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
185
+ z_q = self.decode_code(indices)
186
+
187
+ return z_q, indices, dist
models/bicodec_tokenizer/spark_detokenizer.py ADDED
@@ -0,0 +1,106 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Time :2025/3/29 10:34
3
+ # Author :Hui Huang
4
+ import os
5
+ from typing import Literal
6
+
7
+ import torch
8
+ from .base_model import SparkBaseModel
9
+ from .batch_processor import AsyncBatchEngine
10
+ from .tokenizer_utils import get_dtype
11
+ from .modules.encoder_decoder.feat_decoder import Decoder
12
+ from .modules.encoder_decoder.wave_generator import WaveGenerator
13
+ from .modules.speaker.speaker_encoder import SpeakerEncoder
14
+ from .modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
15
+
16
+ __all__ = ["SparkDeTokenizer"]
17
+
18
+
19
+ class SparkDeTokenizerModel(SparkBaseModel):
20
+ def __init__(self, config):
21
+ super().__init__()
22
+
23
+ self.quantizer = FactorizedVectorQuantize(**config["quantizer"])
24
+ self.prenet = Decoder(**config["prenet"])
25
+ self.decoder = WaveGenerator(**config["decoder"])
26
+ self.speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
27
+
28
+ @torch.no_grad()
29
+ def forward(
30
+ self,
31
+ semantic_tokens: torch.Tensor,
32
+ global_tokens: torch.Tensor
33
+ ) -> torch.Tensor:
34
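+ # Semantic tokens -> quantized latents; global tokens -> speaker d-vector.
+ # The prenet conditions the latents on the d-vector, which is also added back before the wave generator reconstructs the waveform.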
+ z_q = self.quantizer.detokenize(semantic_tokens)
35
+ d_vector = self.speaker_encoder.detokenize(global_tokens)
36
+ x = self.prenet(z_q, d_vector)
37
+ x = x + d_vector.unsqueeze(-1)
38
+ wav_recon = self.decoder(x)
39
+ return wav_recon.detach()
40
+
41
+
42
+ class SparkDeTokenizer:
43
+ def __init__(
44
+ self,
45
+ model_path: str,
46
+ device: Literal["cpu", "cuda", "mps"] | str = "cpu",
47
+ batch_size: int = 32,
48
+ wait_timeout: float = 0.01):
49
+ self.device = torch.device(device)
50
+ self.model = SparkDeTokenizerModel.from_pretrained(model_path).to(self.device)
51
+ self.device_type = device
52
+ self.dtype = get_dtype(self.device_type)
53
+ self._batch_processor = AsyncBatchEngine(
54
+ processing_function=self.batch_detokenize_async,
55
+ batch_size=batch_size,
56
+ wait_timeout=wait_timeout
57
+ )
58
+
59
+ @torch.no_grad()
60
+ def detokenize(
61
+ self,
62
+ semantic_tokens: torch.Tensor,
63
+ global_tokens: torch.Tensor
64
+ ) -> torch.Tensor:
65
+ with torch.amp.autocast(self.device_type, dtype=self.dtype):
66
+ output = self.model(
67
+ semantic_tokens.to(self.device),
68
+ global_tokens.to(self.device)
69
+ )
70
+ return output
71
+
72
+ async def batch_detokenize_async(self, requests: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]:
73
+ semantic_tokens, global_tokens = [], []
74
+ lengths = []
75
+ for request in requests:
76
+ semantic_tokens.append(request["semantic_tokens"])
77
+ global_tokens.append(request["global_tokens"])
78
+ lengths.append(len(request['semantic_tokens']))
79
+ # Concatenate tokens for batch processing
80
+ global_tokens = torch.stack(global_tokens, dim=0)
81
+ semantic_tokens = torch.nn.utils.rnn.pad_sequence(
82
+ semantic_tokens, batch_first=True, padding_value=0
83
+ )
84
+ # print(f"tokenizer global_tokens shape is {global_tokens.shape}")
85
+ # print(f"tokenizer semantic_tokens shape is {semantic_tokens.shape}")
86
+ audios = self.detokenize(
87
+ semantic_tokens=semantic_tokens,
88
+ global_tokens=global_tokens
89
+ ).detach().cpu()
90
+ # Prepare responses
91
+ responses = []
92
+ for i in range(len(requests)):
93
+ audio = audios[i, :, :(lengths[i] * 320)]  # roughly one semantic token corresponds to 320 audio samples
94
+ responses.append({
95
+ "audio": audio,
96
+ })
97
+
98
+ if self.device.type == "cuda":
99
+ torch.cuda.empty_cache()
100
+ return responses
101
+
102
+ async def detokenize_async(self, request: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
103
+ output = await self._batch_processor.add_request(
104
+ single_input=request
105
+ )
106
+ return output.get("feature")
models/bicodec_tokenizer/spark_tokenizer.py ADDED
@@ -0,0 +1,244 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Time :2025/3/29 10:30
3
+ # Author :Hui Huang
4
+ import os
5
+ from typing import Literal, Optional, Tuple, Dict, Any, List, Union
6
+
7
+ import torch
8
+ import torchaudio
9
+ import torchaudio.transforms as TT
10
+ from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
11
+ import numpy as np
12
+ from loguru import logger
13
+ from pathlib import Path
14
+
15
+ # ----------------- These modules are assumed to live under your project path -----------------
16
+ from .utils.file import load_config
17
+ from .utils.audio import load_audio
18
+ from .models.bicodec import BiCodec
19
+ from .base_model import SparkBaseModel
20
+ from .batch_processor import AsyncBatchEngine
21
+ # ---------------------------------------------------------------
22
+
23
+ __all__ = ["SparkTokenizer"]
24
+
25
+
26
+ class SparkTokenizer:
27
+ def __init__(
28
+ self,
29
+ model_path: str,
30
+ device: Literal["cpu", "cuda", "mps"] | str = "cuda",
31
+ attn_implementation: Optional[Literal["sdpa", "flash_attention_2", "eager"]] = "eager",
32
+ batch_size: int = 32,
33
+ wait_timeout: float = 0.01,
34
+ ):
35
+ self.device = torch.device(device)
36
+ self.model_dir = Path(model_path)
37
+
38
+ # 1. Load the configuration
39
+ self.config = load_config(self.model_dir / "config.yaml")
40
+ self.device_type = "cuda" if "cuda" in str(device) else "cpu"
41
+ self.dtype = torch.float16 if self.device_type == "cuda" else torch.float32
42
+ self.target_sample_rate = self.config.get("sample_rate", 16000)
43
+
44
+ # 2. Load the models
45
+ wav2vec_path = self.model_dir / "wav2vec2-large-xlsr-53"
46
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
47
+ self.feature_extractor = Wav2Vec2Model.from_pretrained(
48
+ wav2vec_path,
49
+ attn_implementation=attn_implementation,
50
+ torch_dtype=self.dtype
51
+ )
52
+ self.feature_extractor.config.output_hidden_states = True
53
+ self.feature_extractor.to(self.device)
54
+ self.feature_extractor.eval()
55
+
56
+ # BiCodec model
57
+ self.model = (
58
+ BiCodec.load_from_checkpoint(str(self.model_dir)).to(self.device).half()
59
+ )
60
+ self.model.eval()
61
+
62
+ # Asynchronous batch-processing engine
63
+ self._batch_processor = AsyncBatchEngine(
64
+ processing_function=self.batch_tokenize_async,
65
+ batch_size=batch_size,
66
+ wait_timeout=wait_timeout
67
+ )
68
+
69
+ def _to_ndarray(self, audio_input: Union[str, Path, torch.Tensor]) -> np.ndarray:
70
+ """
71
+ Convert the input (path or Tensor) into a numpy array at the target sampling rate.
72
+ """
73
+ if isinstance(audio_input, (str, Path)):
74
+ # If it is a path, use the existing load_audio helper
75
+ wav = load_audio(
76
+ str(audio_input),
77
+ sampling_rate=self.target_sample_rate,
78
+ volume_normalize=self.config.get("volume_normalize", True),
79
+ )
80
+ elif isinstance(audio_input, torch.Tensor):
81
+ # If it is a Tensor
82
+ wav = audio_input.detach().cpu().float()
83
+
84
+ # Collapse channels: [C, T] -> [T]
85
+ if wav.ndim > 1:
86
+ wav = torch.mean(wav, dim=0)
87
+
88
+ # The input Tensor is assumed to already be sampled at self.target_sample_rate;
89
+ # resampling here would require an extra argument for the input sampling rate.
90
+ wav = wav.numpy()
91
+
92
+ # Optional: volume normalization (in case the Tensor is not normalized)
93
+ if self.config.get("volume_normalize", True):
94
+ max_val = np.abs(wav).max()
95
+ if max_val > 0:
96
+ wav = wav / max_val * 0.9
97
+ else:
98
+ raise ValueError(f"Unsupported audio type: {type(audio_input)}")
99
+
100
+ return wav
101
+
102
+ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
103
+ """Get the reference audio clip."""
104
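+ # Reference clip length: ref_segment_duration seconds, rounded down to a multiple of latent_hop_length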
+ ref_segment_length = (
105
+ int(self.target_sample_rate * self.config["ref_segment_duration"])
106
+ // self.config["latent_hop_length"]
107
+ * self.config["latent_hop_length"]
108
+ )
109
+ wav_length = len(wav)
110
+
111
+ if ref_segment_length > wav_length:
112
+ wav = np.tile(wav, ref_segment_length // wav_length + 1)
113
+
114
+ return wav[:ref_segment_length]
115
+
116
+ def process_audio(self, audio_input: Union[str, torch.Tensor], ref_audio_input: Union[str, torch.Tensor] = None) -> Tuple[np.ndarray, torch.Tensor]:
117
+ """
118
+ Process the main audio and the reference audio.
119
+ """
120
+ wav = self._to_ndarray(audio_input)
121
+
122
+ if ref_audio_input is None:
123
+ wav_ref_np = self.get_ref_clip(wav)
124
+ else:
125
+ ref_wav = self._to_ndarray(ref_audio_input)
126
+ wav_ref_np = self.get_ref_clip(ref_wav)
127
+
128
+ wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
129
+ return wav, wav_ref
130
+
131
+ def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
132
+ """Extract wav2vec2 features."""
133
+ # the processor expects a list of numpy arrays
134
+ inputs = self.processor(
135
+ [w.cpu().numpy() for w in wavs],
136
+ sampling_rate=16000,
137
+ return_tensors="pt",
138
+ padding=True,
139
+ ).input_values
140
+
141
+ with torch.no_grad():
142
+ with torch.amp.autocast(self.device_type, dtype=self.dtype):
143
+ feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
144
+
145
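+ # Average three intermediate wav2vec2 hidden layers (11, 14, 16) to form the semantic feature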
+ feats_mix = (
146
+ feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
147
+ ) / 3
148
+
149
+ return feats_mix
150
+
151
+ @torch.no_grad()
152
+ def tokenize(self, audios: List[Union[str, torch.Tensor]]):
153
+ """
154
+ Accepts a list of audio paths or a list of Tensors.
155
+ """
156
+ batch_wavs = []
157
+ batch_ref_wavs = []
158
+
159
+ for audio_item in audios:
160
+ wav, wav_ref = self.process_audio(audio_input=audio_item, ref_audio_input=audio_item)
161
+ batch_wavs.append(torch.from_numpy(wav).float())
162
+ batch_ref_wavs.append(wav_ref.squeeze(0))
163
+
164
+ # Padding wavs
165
+ wav_lengths = [len(w) for w in batch_wavs]
166
+ max_wav_len = max(wav_lengths)
167
+ padded_wavs = torch.zeros(len(batch_wavs), max_wav_len, dtype=self.dtype).to(self.device)
168
+ for i, w in enumerate(batch_wavs):
169
+ padded_wavs[i, :len(w)] = w.to(self.dtype)
170
+
171
+ # Padding ref_wavs
172
+ ref_lengths = [len(w) for w in batch_ref_wavs]
173
+ max_ref_len = max(ref_lengths)
174
+ padded_ref_wavs = torch.zeros(len(batch_ref_wavs), max_ref_len, dtype=self.dtype).to(self.device)
175
+ for i, w in enumerate(batch_ref_wavs):
176
+ padded_ref_wavs[i, :len(w)] = w.to(self.dtype)
177
+
178
+ # Extract features
179
+ feats = self.extract_wav2vec2_features(padded_wavs)
180
+
181
+ batch = {
182
+ "wav": padded_wavs,
183
+ "ref_wav": padded_ref_wavs,
184
+ "feat": feats,
185
+ }
186
+
187
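+ # BiCodec returns frame-level semantic tokens plus a fixed-length set of global speaker tokens per utterance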
+ semantic_tokens, global_tokens = self.model.tokenize(batch)
188
+
189
+ if self.device.type == "cuda":
190
+ torch.cuda.empty_cache()
191
+
192
+ return {"semantic_tokens": semantic_tokens, "global_tokens": global_tokens}
193
+
194
+ async def batch_tokenize_async(self, audios: list) -> list[dict[str, torch.Tensor]]:
195
+ tokenized = self.tokenize(audios)
196
+ responses = []
197
+ for i in range(len(audios)):
198
+ responses.append({
199
+ "global_tokens": tokenized["global_tokens"][i],
200
+ "semantic_tokens": tokenized["semantic_tokens"][i]
201
+ })
202
+ return responses
203
+
204
+ async def tokenize_async(self, audio: Union[str, torch.Tensor]) -> dict[str, torch.Tensor]:
205
+ output = await self._batch_processor.add_request(
206
+ single_input=audio
207
+ )
208
+ return output
209
+
210
+ # ------------------------------------------------------------------
211
+ # Test cases
212
+ # ------------------------------------------------------------------
213
+ if __name__ == "__main__":
214
+ # Configure your model path here
215
+ MODEL_DIR = "/data/yumu/model/ark_tts_v1"
216
+
217
+ # Initialization
218
+ # Note: without a real model directory this line will fail with a file-not-found error; run it where the model files exist
219
+ tokenizer = SparkTokenizer(model_path=MODEL_DIR, device="cuda" if torch.cuda.is_available() else "cpu")
220
+
221
+ # Prepare data: one existing local wav path and one constructed Tensor
222
+ dummy_wav_path = "/data/yumu/arktts/dufu.wav"
223
+ # Build a 2-second 16 kHz audio Tensor (assuming the model expects 16 kHz)
224
+ import torchaudio
225
+ dummy_tensor, sr = torchaudio.load(dummy_wav_path)
226
+
227
+ # 1. Test path input
228
+ print("Testing path input...")
229
+ if os.path.exists(dummy_wav_path):
230
+ res1 = tokenizer.tokenize([dummy_wav_path])
231
+ print(f"Path results: {res1['semantic_tokens'].shape}")
232
+
233
+ # 2. Test Tensor input
234
+ print("Testing tensor input...")
235
+ res2 = tokenizer.tokenize([dummy_tensor])
236
+ print(f"Tensor results: {res2['semantic_tokens'].shape}")
237
+
238
+ # 3. Test mixed input (a list containing both str and Tensor)
239
+ print("Testing mixed input...")
240
+ # For the demo, pass the same tensor twice
241
+ res3 = tokenizer.tokenize([dummy_tensor, dummy_tensor])
242
+ print(f"Mixed results: {res3['semantic_tokens'].shape}")
243
+
244
+ print("All tests passed!")
models/bicodec_tokenizer/tokenizer_utils.py ADDED
@@ -0,0 +1,44 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Time :2025/3/29 10:27
3
+ # Author :Hui Huang
4
+ from omegaconf import OmegaConf, DictConfig
5
+ import torch
6
+
7
+
8
+ def load_config(config_path: str) -> DictConfig:
9
+ """Loads a configuration file and optionally merges it with a base configuration.
10
+
11
+ Args:
12
+ config_path (Path): Path to the configuration file.
13
+ """
14
+ # Load the initial configuration from the given path
15
+ config = OmegaConf.load(config_path)
16
+
17
+ # Check if there is a base configuration specified and merge if necessary
18
+ if config.get("base_config", None) is not None:
19
+ base_config = OmegaConf.load(config["base_config"])
20
+ config = OmegaConf.merge(base_config, config)
21
+
22
+ return config
23
+
24
+
25
+ def gpu_supports_fp16() -> bool:
26
+ # 1. Make sure CUDA is available
27
+ if not torch.cuda.is_available():
28
+ return False
29
+
30
+ # 2. Get the device's compute capability
31
+ major, minor = torch.cuda.get_device_capability()
32
+
33
+ # 3. Check whether it is >= 5.3 (fp16 support)
34
+ if major > 5 or (major == 5 and minor >= 3):
35
+ return True
36
+ else:
37
+ return False
38
+
39
+
40
+ def get_dtype(device: str):
41
+ if device.startswith('cuda') and gpu_supports_fp16():
42
+ return torch.float16
43
+ else:
44
+ return torch.float32
models/bicodec_tokenizer/utils/__init__.py ADDED
File without changes
models/bicodec_tokenizer/utils/audio.py ADDED
@@ -0,0 +1,271 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Description:
17
+ This script contains a collection of functions designed to handle various
18
+ audio processing tasks.
19
+ """
20
+
21
+ import random
22
+ import soxr
23
+ import soundfile
24
+ import torch
25
+ import torchaudio
26
+ import numpy as np
27
+
28
+ from pathlib import Path
29
+ from typing import Tuple
30
+ from numpy.lib.stride_tricks import sliding_window_view
31
+
32
+
33
+ def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
34
+ """
35
+ Normalize the volume of an audio signal.
36
+
37
+ Parameters:
38
+ audio (numpy array): Input audio signal array.
39
+ coeff (float): Target coefficient for normalization, default is 0.2.
40
+
41
+ Returns:
42
+ numpy array: The volume-normalized audio signal.
43
+ """
44
+ # Sort the absolute values of the audio signal
45
+ temp = np.sort(np.abs(audio))
46
+
47
+ # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
48
+ if temp[-1] < 0.1:
49
+ scaling_factor = max(
50
+ temp[-1], 1e-3
51
+ ) # Prevent division by zero with a small constant
52
+ audio = audio / scaling_factor * 0.1
53
+
54
+ # Filter out values less than 0.01 from temp
55
+ temp = temp[temp > 0.01]
56
+ L = temp.shape[0] # Length of the filtered array
57
+
58
+ # If there are fewer than or equal to 10 significant values, return the audio without further processing
59
+ if L <= 10:
60
+ return audio
61
+
62
+ # Compute the average of the top 10% to 1% of values in temp
63
+ volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
64
+
65
+ # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
66
+ audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
67
+
68
+ # Ensure the maximum absolute value in the audio does not exceed 1
69
+ max_value = np.max(np.abs(audio))
70
+ if max_value > 1:
71
+ audio = audio / max_value
72
+
73
+ return audio
74
+
75
+
76
+ def load_audio(
77
+ adfile: Path,
78
+ sampling_rate: int = None,
79
+ length: int = None,
80
+ volume_normalize: bool = False,
81
+ segment_duration: int = None,
82
+ ) -> np.ndarray:
83
+ r"""Load audio file with target sampling rate and lsength
84
+
85
+ Args:
86
+ adfile (Path): path to audio file.
87
+ sampling_rate (int, optional): target sampling rate. Defaults to None.
88
+ length (int, optional): target audio length. Defaults to None.
89
+ volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
90
+ segment_duration (int, optional): randomly select a segment of {segment_duration} seconds.
91
+ Defaults to None, which means the whole audio is used.
92
+
93
+ Returns:
94
+ audio (np.ndarray): audio
95
+ """
96
+
97
+ audio, sr = soundfile.read(adfile)
98
+ if len(audio.shape) > 1:
99
+ audio = audio[:, 0]
100
+
101
+ if sampling_rate is not None and sr != sampling_rate:
102
+ audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
103
+ sr = sampling_rate
104
+
105
+ if segment_duration is not None:
106
+ seg_length = int(sr * segment_duration)
107
+ audio = random_select_audio_segment(audio, seg_length)
108
+
109
+ # Audio volume normalize
110
+ if volume_normalize:
111
+ audio = audio_volume_normalize(audio)
112
+ # check the audio length
113
+ if length is not None:
114
+ assert abs(audio.shape[0] - length) < 1000
115
+ if audio.shape[0] > length:
116
+ audio = audio[:length]
117
+ else:
118
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
119
+ return audio
120
+
121
+
122
+ def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
123
+ """get an audio segment given the length
124
+
125
+ Args:
126
+ audio (np.ndarray):
127
+ length (int): audio length = sampling_rate * duration
128
+ """
129
+ if audio.shape[0] < length:
130
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
131
+ start_index = random.randint(0, audio.shape[0] - length)
132
+ end_index = int(start_index + length)
133
+
134
+ return audio[start_index:end_index]
135
+
136
+
137
+ def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
138
+ """apply highpass fileter to audio
139
+
140
+ Args:
141
+ audio (np.ndarray):
142
+ sample_rate (int):
143
+ highpass_cutoff_freq (int):
144
+ """
145
+
146
+ audio = torchaudio.functional.highpass_biquad(
147
+ torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
148
+ )
149
+ return audio.numpy()
150
+
151
+
152
+ def stft(
153
+ x: torch.Tensor,
154
+ fft_size: int,
155
+ hop_size: int,
156
+ win_length: int,
157
+ window: torch.Tensor,
158
+ use_complex: bool = False,
159
+ ) -> torch.Tensor:
160
+ """Perform STFT and convert to magnitude spectrogram.
161
+ Args:
162
+ x (Tensor): Input signal tensor (B, T).
163
+ fft_size (int): FFT size.
164
+ hop_size (int): Hop size.
165
+ win_length (int): Window length.
166
+ window (Tensor): Window tensor (moved to the input's device before the STFT).
167
+ Returns:
168
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
169
+ """
170
+
171
+ x_stft = torch.stft(
172
+ x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
173
+ )
174
+
175
+ # clamp is needed to avoid nan or inf
176
+ if not use_complex:
177
+ return torch.sqrt(
178
+ torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
179
+ ).transpose(2, 1)
180
+ else:
181
+ res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
182
+ res = res.transpose(2, 3) # [B, 2, T, F]
183
+ return res
184
+
185
+
186
+ def detect_speech_boundaries(
187
+ wav: np.ndarray,
188
+ sample_rate: int,
189
+ window_duration: float = 0.1,
190
+ energy_threshold: float = 0.01,
191
+ margin_factor: int = 2
192
+ ) -> Tuple[int, int]:
193
+ """Detect the start and end points of speech in an audio signal using RMS energy.
194
+
195
+ Args:
196
+ wav: Input audio signal array with values in [-1, 1]
197
+ sample_rate: Audio sample rate in Hz
198
+ window_duration: Duration of detection window in seconds
199
+ energy_threshold: RMS energy threshold for speech detection
200
+ margin_factor: Factor to determine extra margin around detected boundaries
201
+
202
+ Returns:
203
+ tuple: (start_index, end_index) of speech segment
204
+
205
+ Raises:
206
+ ValueError: If the audio contains only silence
207
+ """
208
+ window_size = int(window_duration * sample_rate)
209
+ margin = margin_factor * window_size
210
+ step_size = window_size // 10
211
+
212
+ # Create sliding windows using stride tricks to avoid loops
213
+ windows = sliding_window_view(wav, window_size)[::step_size]
214
+
215
+ # Calculate RMS energy for each window
216
+ energy = np.sqrt(np.mean(windows ** 2, axis=1))
217
+ speech_mask = energy >= energy_threshold
218
+
219
+ if not np.any(speech_mask):
220
+ raise ValueError("No speech detected in audio (only silence)")
221
+
222
+ start = max(0, np.argmax(speech_mask) * step_size - margin)
223
+ end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
224
+
225
+ return start, end
226
+
227
+
228
+ def remove_silence_on_both_ends(
229
+ wav: np.ndarray,
230
+ sample_rate: int,
231
+ window_duration: float = 0.1,
232
+ volume_threshold: float = 0.01
233
+ ) -> np.ndarray:
234
+ """Remove silence from both ends of an audio signal.
235
+
236
+ Args:
237
+ wav: Input audio signal array
238
+ sample_rate: Audio sample rate in Hz
239
+ window_duration: Duration of detection window in seconds
240
+ volume_threshold: Amplitude threshold for silence detection
241
+
242
+ Returns:
243
+ np.ndarray: Audio signal with silence removed from both ends
244
+
245
+ Raises:
246
+ ValueError: If the audio contains only silence
247
+ """
248
+ start, end = detect_speech_boundaries(
249
+ wav,
250
+ sample_rate,
251
+ window_duration,
252
+ volume_threshold
253
+ )
254
+ return wav[start:end]
255
+
256
+
257
+
258
+ def hertz_to_mel(pitch: float) -> float:
259
+ """
260
+ Converts a frequency from the Hertz scale to the Mel scale.
261
+
262
+ Parameters:
263
+ - pitch: float or ndarray
264
+ Frequency in Hertz.
265
+
266
+ Returns:
267
+ - mel: float or ndarray
268
+ Frequency in Mel scale.
269
+ """
270
+ mel = 2595 * np.log10(1 + pitch / 700)
271
+ return mel
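A minimal usage sketch for the helpers above (not part of the committed file; the file path is a placeholder and the functions are assumed importable from this module):

    wav = load_audio("example.wav", sampling_rate=16000, volume_normalize=True)
    trimmed = remove_silence_on_both_ends(wav, sample_rate=16000)  # raises ValueError on pure silence
    print(len(wav), len(trimmed), hertz_to_mel(440.0))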
models/bicodec_tokenizer/utils/file.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang (w.xinshawn@gmail.com)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Description:
17
+ This script contains a collection of functions designed to handle various
18
+ file reading and writing operations. It provides utilities to read from files,
19
+ write data to files, and perform file manipulation tasks.
20
+ """
21
+
22
+
23
+ import os
24
+ import json
26
+ import csv
27
+
28
+ from tqdm import tqdm
29
+ from typing import List, Dict, Any, Set, Union
30
+ from pathlib import Path
31
+ from omegaconf import OmegaConf, DictConfig
32
+
33
+
34
+ def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
35
+ """
36
+ Resolves the absolute path of a symbolic link.
37
+
38
+ Args:
39
+ symbolic_link_path (Path): The path to the symbolic link.
40
+
41
+ Returns:
42
+ Path: The absolute path that the symbolic link points to.
43
+ """
44
+
45
+ link_directory = os.path.dirname(symbolic_link_path)
46
+ target_path_relative = os.readlink(symbolic_link_path)
47
+ return os.path.join(link_directory, target_path_relative)
48
+
49
+
50
+ def write_jsonl(metadata: List[dict], file_path: Path) -> None:
51
+ """Writes a list of dictionaries to a JSONL file.
52
+
53
+ Args:
54
+ metadata : List[dict]
55
+ A list of dictionaries, each representing a piece of meta.
56
+ file_path : Path
57
+ The file path to save the JSONL file
58
+
59
+ This function writes each dictionary in the list to a new line in the specified file.
60
+ """
61
+ with open(file_path, "w", encoding="utf-8") as f:
62
+ for meta in tqdm(metadata, desc="writing jsonl"):
63
+ # Convert dictionary to JSON string and write it to the file with a newline
64
+ json_str = json.dumps(meta, ensure_ascii=False) + "\n"
65
+ f.write(json_str)
66
+ print(f"jsonl saved to {file_path}")
67
+
68
+
69
+ def read_jsonl(file_path: Path) -> List[dict]:
70
+ """
71
+ Reads a JSONL file and returns a list of dictionaries.
72
+
73
+ Args:
74
+ file_path : Path
75
+ The path to the JSONL file to be read.
76
+
77
+ Returns:
78
+ List[dict]
79
+ A list of dictionaries parsed from each line of the JSONL file.
80
+ """
81
+ metadata = []
82
+ # Open the file for reading
83
+ with open(file_path, "r", encoding="utf-8") as f:
84
+ # Split the file into lines
85
+ lines = f.read().splitlines()
86
+ # Process each line
87
+ for line in lines:
88
+ # Convert JSON string back to dictionary and append to list
89
+ meta = json.loads(line)
90
+ metadata.append(meta)
91
+ # Return the list of metadata
92
+ return metadata
93
+
94
+ def read_json_as_jsonl(file_path: Path) -> List[dict]:
95
+ metadata = []
96
+ with open(file_path, 'r', encoding='utf-8') as infile:
97
+ data = json.load(infile)
98
+ for k in sorted(data.keys()):
99
+ meta = {'index': k}
100
+ meta.update(data[k])
101
+ metadata.append(meta)
102
+ return metadata
103
+
104
+
105
+
106
+ def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
107
+ processed_meta = {}
108
+ for k, v in meta.items():
109
+ if isinstance(v, str):
110
+ processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
111
+ else:
112
+ processed_meta[k] = v
113
+ return processed_meta
114
+
115
+
116
+ def load_config(config_path: Path) -> DictConfig:
117
+ """Loads a configuration file and optionally merges it with a base configuration.
118
+
119
+ Args:
120
+ config_path (Path): Path to the configuration file.
121
+ """
122
+ # Load the initial configuration from the given path
123
+ config = OmegaConf.load(config_path)
124
+
125
+ # Check if there is a base configuration specified and merge if necessary
126
+ if config.get("base_config", None) is not None:
127
+ base_config = OmegaConf.load(config["base_config"])
128
+ config = OmegaConf.merge(base_config, config)
129
+
130
+ return config
131
+
132
+
133
+
134
+ def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
135
+ """
136
+ Converts a JSONL file to a CSV file.
137
+
138
+ This function reads a JSONL file, determines all unique keys present in the file,
139
+ and writes the data to a CSV file with columns for all these keys.
140
+ """
141
+
142
+ all_keys = set()
143
+ data_rows = []
144
+
145
+ # Read the JSONL file once to extract keys and collect data
146
+ with open(jsonl_file_path, 'r') as file:
147
+ for line in file:
148
+ data = json.loads(line.strip())
149
+ data_rows.append(data)
150
+ all_keys.update(data.keys())
151
+
152
+ # Convert the set of keys to a sorted list for consistent column order
153
+ sorted_keys = sorted(all_keys)
154
+
155
+ # Write the data to a CSV file
156
+ with open(csv_file_path, 'w', newline='') as csvfile:
157
+ writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
158
+
159
+ # Write the header row
160
+ writer.writeheader()
161
+
162
+ # Write each row of data
163
+ for data in data_rows:
164
+ writer.writerow(data)
165
+
166
+ print(f"CSV file has been created at {csv_file_path}")
167
+
168
+
169
+ def save_metadata(data, filename, headers=None):
170
+ """
171
+ Save metadata to a file.
172
+
173
+ Args:
174
+ data (list of dict): Metadata to be saved.
175
+ filename (str): Name of the file to save the metadata.
176
+ headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
177
+ """
178
+ # Set headers to keys from the first dictionary in data if not explicitly provided
179
+ if headers is None:
180
+ headers = list(data[0].keys())
181
+
182
+ with open(filename, "w", encoding="utf-8") as file:
183
+ # Write the headers to the file
184
+ file.write("|".join(headers) + "\n")
185
+ for entry in data:
186
+ # Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
187
+ formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
188
+ # Write the formatted values to the file
189
+ file.write("|".join(formatted_values) + "\n")
190
+
191
+
192
+ def read_metadata(filename, headers=None):
193
+ """
194
+ Read metadata from a file.
195
+
196
+ Args:
197
+ filename (str): The file from which to read the metadata.
198
+
199
+ Returns:
200
+ list of dict: The metadata read from the file.
201
+ list of str: The headers used in the file.
202
+ """
203
+ with open(filename, "r", encoding="utf-8") as file:
204
+ lines = file.readlines()
205
+
206
+ data = []
207
+ # Set headers from the first line of the file if not provided
208
+ if headers is None:
209
+ headers = lines[0].strip().split("|")
210
+ lines = lines[1:]
211
+
212
+ for line in lines:
213
+ line = line.strip()
214
+ # Skip empty lines
215
+ if not line:
216
+ continue
217
+ # Split the line by '|' and pair with headers to form a dictionary
218
+ entry_data = dict(zip(headers, line.split("|")))
219
+ data.append(entry_data)
220
+
221
+ return data, headers
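A minimal usage sketch for the JSONL and pipe-separated metadata helpers above (not part of the committed file; paths are placeholders):

    records = [{"id": "utt_0001", "text": "hello"}, {"id": "utt_0002", "text": "world"}]
    write_jsonl(records, "meta.jsonl")
    assert read_jsonl("meta.jsonl") == records

    save_metadata(records, "meta.psv")            # '|'-separated rows, header on the first line
    rows, headers = read_metadata("meta.psv")
    print(headers, rows[0])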
models/bicodec_tokenizer/utils/parse_options.sh ADDED
@@ -0,0 +1,97 @@
1
+ #!/bin/bash
2
+
3
+ # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4
+ # Arnab Ghoshal, Karel Vesely
5
+
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
+ # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14
+ # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15
+ # MERCHANTABILITY OR NON-INFRINGEMENT.
16
+ # See the Apache 2 License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ # Parse command-line options.
21
+ # To be sourced by another script (as in ". parse_options.sh").
22
+ # Option format is: --option-name arg
23
+ # and shell variable "option_name" gets set to value "arg."
24
+ # The exception is --help, which takes no arguments, but prints the
25
+ # $help_message variable (if defined).
26
+
27
+
28
+ ###
29
+ ### The --config file options have lower priority to command line
30
+ ### options, so we need to import them first...
31
+ ###
32
+
33
+ # Now import all the configs specified by command-line, in left-to-right order
34
+ # for ((argpos=1; argpos<$#; argpos++)); do
35
+ # if [ "${!argpos}" == "--config" ]; then
36
+ # argpos_plus1=$((argpos+1))
37
+ # config=${!argpos_plus1}
38
+ # [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39
+ # . $config # source the config file.
40
+ # fi
41
+ # done
42
+
43
+
44
+ ###
45
+ ### Now we process the command line options
46
+ ###
47
+ while true; do
48
+ [ -z "${1:-}" ] && break; # break if there are no arguments
49
+ case "$1" in
50
+ # If the enclosing script is called with --help option, print the help
51
+ # message and exit. Scripts should put help messages in $help_message
52
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53
+ else printf "$help_message\n" 1>&2 ; fi;
54
+ exit 0 ;;
55
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56
+ exit 1 ;;
57
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
58
+ # then work out the variable name as $name, which will equal "foo_bar".
59
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60
+ # Next we test whether the variable in question is undefined -- if so it's
61
+ # an invalid option and we die. Note: $0 evaluates to the name of the
62
+ # enclosing script.
63
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64
+ # is undefined. We then have to wrap this test inside "eval" because
65
+ # foo_bar is itself inside a variable ($name).
66
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67
+
68
+ oldval="`eval echo \\$$name`";
69
+ # Work out whether we seem to be expecting a Boolean argument.
70
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71
+ was_bool=true;
72
+ else
73
+ was_bool=false;
74
+ fi
75
+
76
+ # Set the variable to the right value-- the escaped quotes make it work if
77
+ # the option had spaces, like --cmd "queue.pl -sync y"
78
+ eval $name=\"$2\";
79
+
80
+ # Check that Boolean-valued arguments are really Boolean.
81
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83
+ exit 1;
84
+ fi
85
+ shift 2;
86
+ ;;
87
+ *) break;
88
+ esac
89
+ done
90
+
91
+
92
+ # Check for an empty argument to the --cmd option, which can easily occur as a
93
+ # result of scripting errors.
94
+ [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95
+
96
+
97
+ true; # so this script returns exit code 0.
models/bicodec_tokenizer/utils/token_parser.py ADDED
@@ -0,0 +1,187 @@
1
+ TASK_TOKEN_MAP = {
2
+ "vc": "<|task_vc|>",
3
+ "tts": "<|task_tts|>",
4
+ "asr": "<|task_asr|>",
5
+ "s2s": "<|task_s2s|>",
6
+ "t2s": "<|task_t2s|>",
7
+ "understand": "<|task_understand|>",
8
+ "caption": "<|task_cap|>",
9
+ "controllable_tts": "<|task_controllable_tts|>",
10
+ "prompt_tts": "<|task_prompt_tts|>",
11
+ "speech_edit": "<|task_edit|>",
12
+ }
13
+
14
+ LEVELS_MAP = {
15
+ "very_low": 0,
16
+ "low": 1,
17
+ "moderate": 2,
18
+ "high": 3,
19
+ "very_high": 4,
20
+ }
21
+
22
+ LEVELS_MAP_UI = {
23
+ 1: 'very_low',
24
+ 2: 'low',
25
+ 3: 'moderate',
26
+ 4: 'high',
27
+ 5: 'very_high'
28
+ }
29
+
30
+ GENDER_MAP = {
31
+ "female": 0,
32
+ "male": 1,
33
+ }
34
+
35
+ AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
36
+
37
+ EMO_MAP = {
38
+ "UNKNOWN": 0,
39
+ "NEUTRAL": 1,
40
+ "ANGRY": 2,
41
+ "HAPPY": 3,
42
+ "SAD": 4,
43
+ "FEARFUL": 5,
44
+ "DISGUSTED": 6,
45
+ "SURPRISED": 7,
46
+ "SARCASTIC": 8,
47
+ "EXCITED": 9,
48
+ "SLEEPY": 10,
49
+ "CONFUSED": 11,
50
+ "EMPHASIS": 12,
51
+ "LAUGHING": 13,
52
+ "SINGING": 14,
53
+ "WORRIED": 15,
54
+ "WHISPER": 16,
55
+ "ANXIOUS": 17,
56
+ "NO-AGREEMENT": 18,
57
+ "APOLOGETIC": 19,
58
+ "CONCERNED": 20,
59
+ "ENUNCIATED": 21,
60
+ "ASSERTIVE": 22,
61
+ "ENCOURAGING": 23,
62
+ "CONTEMPT": 24,
63
+ }
64
+
65
+
66
+ class TokenParser:
67
+ """Turn label to special token"""
68
+
69
+ def __init__(self):
70
+ pass
71
+
72
+ """Parse the attributes of a person."""
73
+
74
+ def __init__(self):
75
+ pass
76
+
77
+ @staticmethod
78
+ def age(age: str) -> str:
79
+ """Turn age token."""
80
+ age_id = AGE_MAP[age]
81
+ return f"<|age_{age_id}|>"
82
+
83
+ @staticmethod
84
+ def gender(gender: str) -> str:
85
+ """Turn gender token."""
86
+ gender_id = GENDER_MAP[gender]
87
+ return f"<|gender_{gender_id}|>"
88
+
89
+ @staticmethod
90
+ def mel_value(mel: int):
91
+ """Turn special token of mel scale pitch."""
92
+ mel = max(0, int(mel))
93
+ mel = min(1000, int(mel))
94
+ return f"<|pitch_value_{mel}|>"
95
+
96
+ @staticmethod
97
+ def mel_level(level: str):
98
+ """Turn special token of mel level."""
99
+ level_tag = LEVELS_MAP[level]
100
+ return f"<|pitch_label_{level_tag}|>"
101
+
102
+ @staticmethod
103
+ def pitch_var_value(pitch_std: int):
104
+ """Turn special token of pitch_std value."""
105
+ assert isinstance(pitch_std, int)
106
+ pitch_std = max(0, int(pitch_std))
107
+ pitch_std = min(10, int(pitch_std))
108
+ return f"<|pitch_var_value_{pitch_std}|>"
109
+
110
+ @staticmethod
111
+ def pitch_var_level(level: str):
112
+ """Turn special token of pitch std level."""
113
+ level_tag = LEVELS_MAP[level]
114
+ return f"<|pitch_var_label_{level_tag}|>"
115
+
116
+ @staticmethod
117
+ def loudness_value(loudness: int):
118
+ """Turn special toak of loudness value [0, 30]"""
119
+ assert loudness >= 0
120
+ loudness = max(0, int(loudness))
121
+ loudness = min(30, int(loudness))
122
+ return f"<|loudness_value_{loudness}|>"
123
+
124
+ @staticmethod
125
+ def loudness_level(level: str):
126
+ """Turn special token of loudness level."""
127
+ level_tag = LEVELS_MAP[level]
128
+ return f"<|loudness_label_{level_tag}|>"
129
+
130
+ @staticmethod
131
+ def speed_value(speed: int):
132
+ """Turn special token of speed value."""
133
+ speed = max(0, int(speed))
134
+ speed = min(10, int(speed))
135
+ return f"<|speed_value_{speed}|>"
136
+
137
+ @staticmethod
138
+ def speed_level(level: str):
139
+ """Turn special token of speed level."""
140
+ level_tag = LEVELS_MAP[level]
141
+ return f"<|speed_label_{level_tag}|>"
142
+
143
+ @staticmethod
144
+ def task(task: str) -> str:
145
+ """Turn special token of task."""
146
+ assert task in TASK_TOKEN_MAP.keys()
147
+
148
+ return TASK_TOKEN_MAP[task]
149
+
150
+ @staticmethod
151
+ def emotion(emotion: str):
152
+ emo_id = EMO_MAP[emotion]
153
+
154
+ return f"<|emotion_{emo_id}|>"
155
+
156
+
157
+ # test
158
+ if __name__ == "__main__":
159
+ from transformers import AutoTokenizer
160
+
161
+ tokenizer = AutoTokenizer.from_pretrained(
162
+ "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
163
+ )
164
+
165
+ tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
166
+ ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
167
+ genders = ["female", "female", "female", "male", "male"]
168
+ mels = [100, 200, 300, 400, 500]
169
+ mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
170
+ loudnesses = [1, 10, 23, 19, 30]
171
+ loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
172
+ emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
173
+
174
+ for i in range(5):
175
+ task = TokenParser.task(tasks[i])
176
+ age = TokenParser.age(ages[i])
177
+ gender = TokenParser.gender(genders[i])
178
+ mel = TokenParser.mel_value(mels[i])
179
+ mel_level = TokenParser.mel_level(mel_levels[i])
180
+ loudness = TokenParser.loudness_value(loudnesses[i])
181
+ loudness_level = TokenParser.loudness_level(loudness_levels[i])
182
+ emotion = TokenParser.emotion(emotions[i])
183
+ inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
184
+ inputs = "".join(inputs)
185
+ ids = tokenizer.encode(inputs, add_special_tokens=False)
186
+ print(ids)
187
+ print("decode", tokenizer.decode(ids))
models/glm_speech_tokenizer/__init__.py ADDED
File without changes
models/glm_speech_tokenizer/batch_processor.py ADDED
@@ -0,0 +1,182 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Time :2024/11/17 15:33
3
+ # Author :Hui Huang
4
+ import asyncio
5
+ import uuid
6
+ from typing import Callable, List, Any, Awaitable, Tuple
7
+ from asyncio import Queue
8
+
9
+
10
+ class BatchProcessor:
11
+ """Batch Processor for handling asynchronous requests in batches.
12
+
13
+ This class manages a queue of requests and processes them in batches
14
+ using multiple worker tasks.
15
+
16
+ Attributes:
17
+ processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
18
+ The function used for processing requests in batches.
19
+ num_workers (int): The number of worker tasks to process requests.
20
+ batch_size (int): The maximum number of requests to process in a single batch.
21
+ request_queue (Queue): The queue holding incoming requests.
22
+ loop (asyncio.AbstractEventLoop): The event loop used to create worker tasks.
23
+ worker_tasks (List[asyncio.Task]): The list of worker tasks.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
29
+ num_workers: int,
30
+ batch_size: int,
31
+ wait_timeout: float = 0.05
32
+ ) -> None:
33
+ """Initialize the BatchProcessor with the given processing function, number of workers, and batch size.
34
+
35
+ Args:
36
+ processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
37
+ The function used for processing requests in batches.
38
+ num_workers (int): The number of worker tasks to process requests.
39
+ batch_size (int): The maximum number of requests to process in a single batch.
40
+ """
41
+ self.processing_function = processing_function
42
+ self.num_workers = num_workers
43
+ self.batch_size = batch_size
44
+ self.wait_timeout = wait_timeout
45
+ self.request_queue: Queue = Queue()
46
+ self.loop = asyncio.get_running_loop()
47
+ self.worker_tasks = [
48
+ self.loop.create_task(self.batch_processor(i)) for i in range(num_workers)
49
+ ]
50
+ # Wait until all worker tasks are started
51
+ self.loop.create_task(self._log_workers_started())
52
+
53
+ async def _log_workers_started(self):
54
+ await asyncio.sleep(0) # Yield control to ensure workers have started
55
+
56
+ async def batch_processor(self, worker_id: int):
57
+ """Worker task that processes requests from the queue in batches.
58
+
59
+ Args:
60
+ worker_id (int): The identifier for the worker task.
61
+ """
62
+
63
+ while True:
64
+ requests: List[Tuple[Any, asyncio.Future]] = []
65
+ try:
66
+ while len(requests) < self.batch_size:
67
+ request = await asyncio.wait_for(
68
+ self.request_queue.get(), timeout=self.wait_timeout
69
+ )
70
+ requests.append(request)
71
+ except asyncio.TimeoutError:
72
+ pass
73
+
74
+ if requests:
75
+ all_requests = [
76
+ req[0] for req in requests
77
+ ] # Extract the actual input data from each request tuple
78
+ futures = [req[1] for req in requests] # Extract the futures to resolve
79
+ try:
80
+ results = await self.processing_function(all_requests)
81
+
82
+ for (future, result) in zip(futures, results):
83
+ future.set_result(result)
84
+ except Exception as e:
85
+ for future in futures:
86
+ future.set_exception(e)
87
+
88
+ async def add_request(self, single_input: Any):
89
+ """Add a new request to the queue.
90
+
91
+ Args:
92
+ single_input (Any): The input data for processing.
93
+ """
94
+ # loop = asyncio.get_running_loop()
95
+ future = self.loop.create_future()
96
+ self.request_queue.put_nowait((single_input, future))
97
+ return future
98
+
99
+ async def shutdown(self):
100
+ """Shutdown the batch processor by cancelling all worker tasks."""
101
+ for task in self.worker_tasks:
102
+ task.cancel()
103
+ try:
104
+ await task
105
+ except asyncio.CancelledError:
106
+ print("Worker task cancelled.")
107
+
108
+
109
+ class AsyncBatchEngine:
110
+
111
+ def __init__(
112
+ self,
113
+ processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
114
+ batch_size: int = 32,
115
+ wait_timeout: float = 0.01,
116
+ ):
117
+ """
118
+ Initialize the AsyncBatchEngine with a processing function, number of workers, and batch size.
119
+
120
+ Args:
121
+ processing_function (Callable[[List[Any]], Awaitable[List[Any]]]): The batch processing function.
122
+ batch_size (int): The maximum number of requests to process in a single batch.
123
+ """
124
+ self._processing_function = processing_function
125
+ self._batch_size = batch_size
126
+ self._is_running = False
127
+ self._batch_processor = None
128
+ self._wait_timeout = wait_timeout
129
+
130
+ async def start(self):
131
+ """Start the engine by initializing the batch processor and worker tasks."""
132
+ if self._is_running:
133
+ return
134
+
135
+ self._batch_processor = BatchProcessor(
136
+ processing_function=self._processing_function,
137
+ batch_size=self._batch_size,
138
+ wait_timeout=self._wait_timeout,
139
+ num_workers=1
140
+ )
141
+ self._is_running = True
142
+
143
+ async def stop(self):
144
+ """Stop the engine by shutting down the batch processor and worker tasks."""
145
+ self._check_running()
146
+ self._is_running = False
147
+ if self._batch_processor is not None:
148
+ await self._batch_processor.shutdown()
149
+
150
+ def _check_running(self):
151
+ """Check if the engine is running.
152
+
153
+ Raises:
154
+ ValueError: If the engine is not running.
155
+ """
156
+ if not self._is_running:
157
+ raise ValueError(
158
+ "The engine is not running. "
159
+ "You must start the engine before using it."
160
+ )
161
+
162
+ async def add_request(self, single_input: Any, request_id: str = None) -> dict:
163
+ """Asynchronously add a request to be processed.
164
+
165
+ Args:
166
+ single_input (Any): The input data for processing.
167
+ request_id (str): Optional request identifier to avoid data mix-up.
168
+
169
+ Raises:
170
+ ValueError: If the engine is not running when this method is called.
171
+ """
172
+ if not self._is_running:
173
+ await self.start()
174
+
175
+ if request_id is None:
176
+ request_id = str(uuid.uuid4()) # Assign a unique ID if not provided
177
+ future = await self._batch_processor.add_request(single_input=single_input) # type: ignore
178
+ result = await future
179
+ return dict(
180
+ request_id=request_id,
181
+ feature=result
182
+ )
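A minimal usage sketch for AsyncBatchEngine (not part of the committed file; the upper-casing function stands in for a real batched inference call):

    import asyncio

    async def process_batch(batch):
        # A real processing_function would run one batched model call here.
        return [item.upper() for item in batch]

    async def main():
        engine = AsyncBatchEngine(processing_function=process_batch, batch_size=8, wait_timeout=0.01)
        outputs = await asyncio.gather(*[engine.add_request(t) for t in ["a", "b", "c"]])
        print(outputs)  # each item is {"request_id": ..., "feature": ...}
        await engine.stop()

    asyncio.run(main())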
models/glm_speech_tokenizer/configuration_whisper.py ADDED
@@ -0,0 +1,37 @@
1
+ from transformers import WhisperConfig
2
+
3
+
4
+ class WhisperVQConfig(WhisperConfig):
5
+ def __init__(self,
6
+ pooling_kernel_size=None,
7
+ pooling_type="max",
8
+ pooling_position=0,
9
+ quantize_vocab_size=None,
10
+ quantize_position=16,
11
+ quantize_commit_coefficient=0.25,
12
+ quantize_loss_scale=1.0,
13
+ quantize_ema_decay=None,
14
+ quantize_restart_interval=None,
15
+ quantize_encoder_only=False,
16
+ quantize_causal_encoder=False,
17
+ quantize_causal_block_size=None,
18
+ skip_language_detection=False,
19
+ encoder_causal_attention=False,
20
+ encoder_causal_convolution=False,
21
+ **kwargs):
22
+ self.pooling_kernel_size = pooling_kernel_size
23
+ self.pooling_type = pooling_type
24
+ self.pooling_position = pooling_position
25
+ self.quantize_vocab_size = quantize_vocab_size
26
+ self.quantize_position = quantize_position
27
+ self.quantize_commit_coefficient = quantize_commit_coefficient
28
+ self.quantize_loss_scale = quantize_loss_scale
29
+ self.quantize_ema_decay = quantize_ema_decay
30
+ self.quantize_restart_interval = quantize_restart_interval
31
+ self.quantize_encoder_only = quantize_encoder_only
32
+ self.quantize_causal_encoder = quantize_causal_encoder
33
+ self.quantize_causal_block_size = quantize_causal_block_size
34
+ self.skip_language_detection = skip_language_detection
35
+ self.encoder_causal_attention = encoder_causal_attention
36
+ self.encoder_causal_convolution = encoder_causal_convolution
37
+ super().__init__(**kwargs)
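A small instantiation sketch (not part of the committed file; the values are illustrative, not the settings of any shipped checkpoint):

    config = WhisperVQConfig(
        pooling_kernel_size=4,
        pooling_type="avg",
        quantize_vocab_size=16384,
        quantize_position=16,
        quantize_encoder_only=True,
        encoder_causal_convolution=True,
    )
    print(config.quantize_vocab_size, config.pooling_kernel_size)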
models/glm_speech_tokenizer/generation_whisper.py ADDED
@@ -0,0 +1,1828 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import math
17
+ import warnings
18
+ import zlib
19
+ from typing import Callable, Iterator, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn
25
+
26
+ from transformers.cache_utils import EncoderDecoderCache
27
+
28
+ from transformers.generation.configuration_utils import GenerationConfig
29
+ from transformers.generation.logits_process import (
30
+ LogitsProcessorList,
31
+ SuppressTokensAtBeginLogitsProcessor,
32
+ SuppressTokensLogitsProcessor,
33
+ WhisperNoSpeechDetection,
34
+ WhisperTimeStampLogitsProcessor,
35
+ )
36
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
37
+ from transformers.modeling_outputs import BaseModelOutput
38
+ from transformers.utils import logging
39
+ from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
46
+ """
47
+ Applies a median filter of width `filter_width` along the last dimension of the input.
48
+
49
+ The `inputs` tensor is assumed to be 3- or 4-dimensional.
50
+ """
51
+ if filter_width <= 0 or filter_width % 2 != 1:
52
+ raise ValueError("`filter_width` should be an odd number")
53
+
54
+ pad_width = filter_width // 2
55
+ if inputs.shape[-1] <= pad_width:
56
+ return inputs
57
+
58
+ # Pad the left and right edges.
59
+ inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
60
+
61
+ # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
62
+ result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
63
+ return result
64
+
65
+
66
+ def _dynamic_time_warping(matrix: np.ndarray):
67
+ """
68
+ Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
69
+ token-level timestamps.
70
+ """
71
+ output_length, input_length = matrix.shape
72
+ cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
73
+ trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
74
+
75
+ cost[0, 0] = 0
76
+ for j in range(1, input_length + 1):
77
+ for i in range(1, output_length + 1):
78
+ c0 = cost[i - 1, j - 1]
79
+ c1 = cost[i - 1, j]
80
+ c2 = cost[i, j - 1]
81
+
82
+ if c0 < c1 and c0 < c2:
83
+ c, t = c0, 0
84
+ elif c1 < c0 and c1 < c2:
85
+ c, t = c1, 1
86
+ else:
87
+ c, t = c2, 2
88
+
89
+ cost[i, j] = matrix[i - 1, j - 1] + c
90
+ trace[i, j] = t
91
+
92
+ # backtrace
93
+ i = trace.shape[0] - 1
94
+ j = trace.shape[1] - 1
95
+ trace[0, :] = 2
96
+ trace[:, 0] = 1
97
+
98
+ text_indices = []
99
+ time_indices = []
100
+ while i > 0 or j > 0:
101
+ text_indices.append(i - 1)
102
+ time_indices.append(j - 1)
103
+ if trace[i, j] == 0:
104
+ i -= 1
105
+ j -= 1
106
+ elif trace[i, j] == 1:
107
+ i -= 1
108
+ elif trace[i, j] == 2:
109
+ j -= 1
110
+ else:
111
+ raise RuntimeError(
112
+ f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
113
+ )
114
+
115
+ text_indices = np.array(text_indices)[::-1]
116
+ time_indices = np.array(time_indices)[::-1]
117
+ return text_indices, time_indices
118
+
119
+
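As an aside (not part of the committed file), a toy alignment with the DTW helper above, mapping 3 output tokens onto 4 audio frames; the cost values are made up:

    import numpy as np

    cost = np.array([
        [0.1, 0.8, 0.9, 0.9],
        [0.9, 0.2, 0.3, 0.9],
        [0.9, 0.9, 0.4, 0.1],
    ], dtype=np.float32)
    text_idx, time_idx = _dynamic_time_warping(cost)
    print(list(zip(text_idx, time_idx)))  # a monotonic path from (0, 0) to (2, 3)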
120
+ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
121
+ if logits_processor is not None:
122
+ logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
123
+ if logit_processor:
124
+ return getattr(logit_processor, attribute_name, None)
125
+ return None
126
+
127
+
128
+ def _pad_to_max_length(
129
+ current_segments,
130
+ pad_token_id,
131
+ device,
132
+ padding_side="right",
133
+ padding="longest",
134
+ bos_token_tensor=None,
135
+ cut_off_length=None,
136
+ ):
137
+ max_total_length = 0
138
+ sequences = []
139
+
140
+ if padding_side not in ["right", "left"]:
141
+ raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
142
+
143
+ if padding not in ["longest", "max_length"]:
144
+ raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
145
+ elif padding == "max_length" and cut_off_length is None:
146
+ raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
147
+
148
+ for current_segment_list in current_segments:
149
+ if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
150
+ sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
151
+
152
+ if cut_off_length is not None:
153
+ sequence = sequence[-cut_off_length:]
154
+
155
+ if bos_token_tensor is not None:
156
+ sequence = torch.cat([bos_token_tensor, sequence])
157
+
158
+ sequences.append(sequence)
159
+ max_total_length = max(max_total_length, len(sequences[-1]))
160
+ elif bos_token_tensor is not None:
161
+ sequences.append(bos_token_tensor)
162
+ else:
163
+ sequences.append(torch.tensor([], device=device))
164
+
165
+ max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
166
+ for i in range(len(current_segments)):
167
+ pad_length = max_total_length - len(sequences[i])
168
+ pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
169
+ sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
170
+
171
+ sequences = torch.stack(sequences, dim=0)
172
+ return sequences
173
+
174
+
175
+ class WhisperGenerationMixin:
176
+ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
177
+ """
178
+ Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
179
+ map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
180
+ cross-attentions will be cropped before applying DTW.
181
+
182
+ Returns:
183
+ tensor containing the timestamps in seconds for each predicted token
184
+ """
185
+ # Create a list with `decoder_layers` elements, each a tensor of shape
186
+ # (batch size, attention_heads, output length, input length).
187
+ cross_attentions = []
188
+ for i in range(self.config.decoder_layers):
189
+ cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
190
+
191
+ # Select specific cross-attention layers and heads. This is a tensor
192
+ # of shape (batch size, num selected, output length, input length).
193
+ weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
194
+ weights = weights.permute([1, 0, 2, 3])
195
+
196
+ weight_length = None
197
+
198
+ if "beam_indices" in generate_outputs:
199
+ # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
200
+ # since the beam search strategy chooses the most probable sequences at the end of the search.
201
+ # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
202
+ weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
203
+ weights = weights[:, :, :weight_length]
204
+
205
+ # If beam index is still -1, it means that the associated token id is EOS
206
+ # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
207
+ beam_indices = generate_outputs.beam_indices[:, :weight_length]
208
+ beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
209
+
210
+ # Select the cross attention from the right beam for each output sequences
211
+ weights = torch.stack(
212
+ [
213
+ torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
214
+ for i in range(beam_indices.shape[1])
215
+ ],
216
+ dim=2,
217
+ )
218
+
219
+ # make sure timestamps are as long as weights
220
+ input_length = weight_length or cross_attentions[0].shape[2]
221
+ timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
222
+ batch_size = timestamps.shape[0]
223
+
224
+ if num_frames is not None:
225
+ # two cases:
226
+ # 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
227
+ # 2. num_frames is different, compute the DTW matrix for each sample sequentially
228
+
229
+ # we're using np.unique because num_frames can be int/list/tuple
230
+ if isinstance(num_frames, int):
231
+ weights = weights[..., : num_frames // 2]
232
+
233
+ elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
234
+ weights = weights[..., : num_frames[0] // 2]
235
+
236
+ elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
237
+ weights = weights[..., : num_frames[0] // 2]
238
+
239
+ else:
240
+ # num_frames is of shape (batch_size,) whereas batch_size is really batch_size*num_return_sequences
241
+ repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
242
+ num_frames = np.repeat(num_frames, repeat_time)
243
+
244
+ if num_frames is None or isinstance(num_frames, int):
245
+ # Normalize and smoothen the weights.
246
+ std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
247
+ mean = torch.mean(weights, dim=-2, keepdim=True)
248
+ weights = (weights - mean) / std
249
+ weights = _median_filter(weights, self.config.median_filter_width)
250
+
251
+ # Average the different cross-attention heads.
252
+ weights = weights.mean(dim=1)
253
+
254
+ # Perform dynamic time warping on each element of the batch.
255
+ for batch_idx in range(batch_size):
256
+ if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
257
+ matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
258
+
259
+ # Normalize and smoothen the weights.
260
+ std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
261
+ mean = torch.mean(matrix, dim=-2, keepdim=True)
262
+ matrix = (matrix - mean) / std
263
+ matrix = _median_filter(matrix, self.config.median_filter_width)
264
+
265
+ # Average the different cross-attention heads.
266
+ matrix = matrix.mean(dim=0)
267
+ else:
268
+ matrix = weights[batch_idx]
269
+
270
+ text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
271
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
272
+ jump_times = time_indices[jumps] * time_precision
273
+ timestamps[batch_idx, 1:] = torch.tensor(jump_times)
274
+
275
+ return timestamps
276
+
277
+ def generate(
278
+ self,
279
+ input_features: Optional[torch.Tensor] = None,
280
+ generation_config: Optional[GenerationConfig] = None,
281
+ logits_processor: Optional[LogitsProcessorList] = None,
282
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
283
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
284
+ synced_gpus: bool = False,
285
+ return_timestamps: Optional[bool] = None,
286
+ task: Optional[str] = None,
287
+ language: Optional[Union[str, List[str]]] = None,
288
+ is_multilingual: Optional[bool] = None,
289
+ prompt_ids: Optional[torch.Tensor] = None,
290
+ prompt_condition_type: Optional[str] = None, # first-segment, all-segments
291
+ condition_on_prev_tokens: Optional[bool] = None,
292
+ temperature: Optional[Union[float, Tuple[float, ...]]] = None,
293
+ compression_ratio_threshold: Optional[float] = None,
294
+ logprob_threshold: Optional[float] = None,
295
+ no_speech_threshold: Optional[float] = None,
296
+ num_segment_frames: Optional[int] = None,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ time_precision: float = 0.02,
299
+ return_token_timestamps: Optional[bool] = None,
300
+ return_segments: bool = False,
301
+ return_dict_in_generate: Optional[bool] = None,
302
+ **kwargs,
303
+ ):
304
+ """
305
+ Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
306
+
307
+ <Tip warning={true}>
308
+
309
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
310
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
311
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
312
+
313
+ For an overview of generation strategies and code examples, check out the [following
314
+ guide](./generation_strategies).
315
+
316
+ </Tip>
317
+
318
+ Parameters:
319
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
320
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
321
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
322
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
323
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
324
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
325
+ generation_config (`~generation.GenerationConfig`, *optional*):
326
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
327
+ passed to generate matching the attributes of `generation_config` will override them. If
328
+ `generation_config` is not provided, the default will be used, which had the following loading
329
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
330
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
331
+ default values, whose documentation should be checked to parameterize generation.
332
+ logits_processor (`LogitsProcessorList`, *optional*):
333
+ Custom logits processors that complement the default logits processors built from arguments and
334
+ generation config. If a logit processor is passed that is already created with the arguments or a
335
+ generation config an error is thrown. This feature is intended for advanced users.
336
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
337
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
338
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
339
+ generation config an error is thrown. This feature is intended for advanced users.
340
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
341
+ If provided, this function constraints the beam search to allowed tokens only at each step. If not
342
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
343
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
344
+ on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
345
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
346
+ Retrieval](https://arxiv.org/abs/2010.00904).
347
+ synced_gpus (`bool`, *optional*, defaults to `False`):
348
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
349
+ return_timestamps (`bool`, *optional*):
350
+ Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
351
+ task (`str`, *optional*):
352
+ Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
353
+ will be updated accordingly.
354
+ language (`str` or list of `str`, *optional*):
355
+ Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
356
+ batched generation, a list of language tokens can be passed. You can find all the possible language
357
+ tokens in the `model.generation_config.lang_to_id` dictionary.
358
+ is_multilingual (`bool`, *optional*):
359
+ Whether or not the model is multilingual.
360
+ prompt_ids (`torch.Tensor`, *optional*):
361
+ Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
362
+ provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
363
+ transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
364
+ correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
365
+ prompt_condition_type (`str`, *optional*):
366
+ Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
367
+ Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible.
368
+ condition_on_prev_tokens (`bool`, *optional*):
369
+ Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
370
+ As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
371
+ performance.
372
+ temperature (`float` or list of `float`, *optional*):
373
+ The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates
374
+ generation using sampling. For long-form transcription, temperature fallback can be activated by passing
375
+ a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
376
+ performance.
377
+ compression_ratio_threshold (`float`, *optional*):
378
+ Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of
379
+ a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
380
+ repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates
381
+ suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined
382
+ make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35.
383
+ As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
384
+ performance.
385
+ logprob_threshold (`float`, *optional*):
386
+ Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of
387
+ a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
388
+ repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability
389
+ can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined
390
+ make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0.
391
+ As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
392
+ performance.
393
+ no_speech_threshold (`float`, *optional*):
394
+ Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold`
395
+ is used to determine whether a segment contains only silence. In this case, the transcription for this segment
396
+ is skipped.
397
+ As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
398
+ performance.
399
+ num_segment_frames (`int`, *optional*):
400
+ The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride
401
+ times the maximum input length.
402
+ attention_mask (`torch.Tensor`, *optional*):
403
+ `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
404
+ time_precision (`float`, *optional*, defaults to 0.02):
405
+ The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
406
+ for 20 ms.
407
+ return_token_timestamps (`bool`, *optional*):
408
+ Whether to return token-level timestamps with the text. This can be used with or without the
409
+ `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
410
+ words.
411
+ return_segments (`bool`, *optional*, defaults to `False`):
412
+ Whether to additionally return a list of all segments. Note that this option can only be enabled
413
+ when doing long-form transcription.
414
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
415
+ Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
416
+ Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
417
+ `return_segments` is set to `True`. In this case the generation outputs of each segment are added to each
418
+ segment.
419
+ kwargs (`Dict[str, Any]`, *optional*):
420
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
421
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
422
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
423
+
424
+ Return:
425
+ [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
426
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
427
+
428
+ If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True`, a dictionary is returned that contains the generated sequence ids under `sequences` and a list of each generated segment.
429
+
430
+ else if the passed input is <= 30 seconds / <= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
431
+
432
+ - [`~generation.GenerateEncoderDecoderOutput`],
433
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
434
+
435
+ else only the generated output sequence ids are returned.
436
+
437
+ Example:
438
+
439
+ - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
440
+
441
+ ```python
442
+ >>> import torch
443
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
444
+ >>> from datasets import load_dataset, Audio
445
+
446
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
447
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
448
+ >>> model.cuda() # doctest: +IGNORE_RESULT
449
+
450
+ >>> # load audios > 30 seconds
451
+ >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
452
+ >>> # resample to 16kHz
453
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
454
+ >>> # take first 8 audios and retrieve array
455
+ >>> audio = ds[:8]["audio"]
456
+ >>> audio = [x["array"] for x in audio]
457
+
458
+ >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
459
+ >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
460
+ >>> inputs = inputs.to("cuda", torch.float32)
461
+
462
+ >>> # transcribe audio to ids
463
+ >>> generated_ids = model.generate(**inputs)
464
+
465
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
466
+ >>> transcription[0]
467
+ " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
468
+ ```
469
+
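+ - *Longform transcription with temperature fallback* (a minimal sketch reusing `inputs` from the long-form example above; the 1.35 and -1.0 thresholds follow the argument descriptions, while `no_speech_threshold=0.6` is an assumed choice): passing a tuple of temperatures retries low-quality segments at progressively higher temperatures.
+
+ ```python
+ >>> # retry a segment with a higher temperature whenever the compression-ratio or logprob check fails
+ >>> generated_ids = model.generate(
+ ...     **inputs,
+ ...     return_timestamps=True,
+ ...     temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+ ...     compression_ratio_threshold=1.35,
+ ...     logprob_threshold=-1.0,
+ ...     no_speech_threshold=0.6,
+ ...     condition_on_prev_tokens=True,
+ ... )
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+ ```
+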
470
+ - *Shortform transcription*: If the passed mel input features cover at most 30 seconds of audio, the whole audio will be transcribed with a single call to generate.
471
+
472
+ ```python
473
+ >>> import torch
474
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
475
+ >>> from datasets import load_dataset
476
+
477
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
478
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
479
+
480
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
481
+
482
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
483
+ >>> input_features = inputs.input_features
484
+
485
+ >>> generated_ids = model.generate(inputs=input_features)
486
+
487
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
488
+ >>> transcription
489
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
490
+ ```
491
+
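+ - *Prompting* (a minimal sketch reusing `input_features` from the short-form example above; the prompt string is only illustrative): `prompt_ids` built with [`~WhisperProcessor.get_prompt_ids`] can bias the transcription towards custom vocabulary or proper nouns.
+
+ ```python
+ >>> # hypothetical prompt text; spelling hints or proper nouns can be passed here
+ >>> prompt_ids = processor.get_prompt_ids("Mr. Quilter", return_tensors="pt")
+ >>> generated_ids = model.generate(input_features, prompt_ids=prompt_ids)
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ ```
+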
492
+ """
493
+ # 0. deprecate old inputs
494
+ if "inputs" in kwargs:
495
+ input_features = kwargs.pop("inputs")
496
+ warnings.warn(
497
+ "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
498
+ FutureWarning,
499
+ )
500
+
501
+ # 1. prepare generation config
502
+ generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
503
+
504
+ # 2. set global generate variables
505
+ input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
506
+ num_segment_frames = input_stride * self.config.max_source_positions
507
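+ # note: for the standard Whisper encoder configuration this equals 2 * 1500 = 3000 mel frames, i.e. one 30 s segment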
+ batch_size, total_input_frames = self._retrieve_total_input_frames(
508
+ input_features=input_features, input_stride=input_stride, kwargs=kwargs
509
+ )
510
+ is_shortform = total_input_frames <= num_segment_frames
511
+
512
+ # 3. Make sure generation config is correctly set
513
+ # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
514
+ return_dict_in_generate = self._set_return_outputs(
515
+ return_dict_in_generate=return_dict_in_generate,
516
+ return_token_timestamps=return_token_timestamps,
517
+ logprob_threshold=logprob_threshold,
518
+ generation_config=generation_config,
519
+ )
520
+ timestamp_begin = self._set_return_timestamps(
521
+ return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
522
+ )
523
+ self._set_language_and_task(
524
+ language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
525
+ )
526
+ self._set_num_frames(
527
+ return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
528
+ )
529
+ self._set_thresholds_and_condition(
530
+ generation_config=generation_config,
531
+ logprob_threshold=logprob_threshold,
532
+ compression_ratio_threshold=compression_ratio_threshold,
533
+ no_speech_threshold=no_speech_threshold,
534
+ condition_on_prev_tokens=condition_on_prev_tokens,
535
+ )
536
+ self._set_prompt_condition_type(
537
+ generation_config=generation_config,
538
+ prompt_condition_type=prompt_condition_type,
539
+ )
540
+
541
+ kwargs["attention_mask"] = attention_mask
542
+ # pass self.config for backward compatibility
543
+ init_tokens = self._retrieve_init_tokens(
544
+ input_features,
545
+ batch_size=batch_size,
546
+ generation_config=generation_config,
547
+ config=self.config,
548
+ num_segment_frames=num_segment_frames,
549
+ kwargs=kwargs,
550
+ )
551
+ # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
552
+ # where the input ids are handled explicitly by the generate method
553
+ self._check_decoder_input_ids(kwargs=kwargs)
554
+
555
+ # 3. Retrieve logits processors
556
+ device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
557
+ begin_index = init_tokens.shape[1]
558
+ logits_processor = self._retrieve_logit_processors(
559
+ generation_config=generation_config,
560
+ logits_processor=logits_processor,
561
+ begin_index=begin_index, # begin index is index of first generated decoder token
562
+ num_beams=kwargs.get("num_beams", 1),
563
+ device=device,
564
+ )
565
+
566
+ # 4 Set and retrieve global generation variables
567
+ self._set_condition_on_prev_tokens(
568
+ condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
569
+ )
570
+
571
+ temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
572
+ temperature = temperatures[0]
573
+
574
+ max_frames, seek = self._retrieve_max_frames_and_seek(
575
+ batch_size=batch_size,
576
+ attention_mask=attention_mask,
577
+ total_input_frames=total_input_frames,
578
+ is_shortform=is_shortform,
579
+ )
580
+
581
+ # 5 Prepare running variables, list for generation
582
+ num_return_sequences = generation_config.num_return_sequences
583
+ (
584
+ batch_idx_map,
585
+ cur_bsz,
586
+ input_features,
587
+ seek,
588
+ max_frames,
589
+ init_tokens,
590
+ do_condition_on_prev_tokens,
591
+ ) = self._expand_variables_for_generation(
592
+ input_features=input_features,
593
+ seek=seek,
594
+ max_frames=max_frames,
595
+ init_tokens=init_tokens,
596
+ batch_size=batch_size,
597
+ condition_on_prev_tokens=condition_on_prev_tokens,
598
+ generation_config=generation_config,
599
+ )
600
+
601
+ current_segments = self._prepare_segments(
602
+ prompt_ids=prompt_ids,
603
+ batch_size=cur_bsz,
604
+ generation_config=generation_config,
605
+ )
606
+
607
+ # 6 Transcribe audio until we reach the end of all input audios
608
+ while (seek < max_frames).any():
609
+ # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
610
+ # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
611
+ # to know which original audio is being decoded
612
+ # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
613
+ input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
614
+ input_features=input_features,
615
+ seek=seek,
616
+ max_frames=max_frames,
617
+ cur_bsz=cur_bsz,
618
+ batch_idx_map=batch_idx_map,
619
+ )
620
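+ # note: `seek` counts mel frames (100 per second); dividing by the input stride and multiplying by the 0.02 s token precision converts the current position into seconds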
+ time_offset = seek * time_precision / input_stride
621
+ seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
622
+
623
+ # 6.2 cut out next 30s segment from input features
624
+ segment_input = self._get_input_segment(
625
+ input_features=input_features,
626
+ seek=seek,
627
+ seek_num_frames=seek_num_frames,
628
+ num_segment_frames=num_segment_frames,
629
+ cur_bsz=cur_bsz,
630
+ batch_idx_map=batch_idx_map,
631
+ )
632
+
633
+ # 6.3 prepare decoder input ids
634
+ suppress_tokens = _get_attr_from_logit_processors(
635
+ logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
636
+ )
637
+
638
+ decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
639
+ cur_bsz=cur_bsz,
640
+ init_tokens=init_tokens,
641
+ current_segments=current_segments,
642
+ batch_idx_map=batch_idx_map,
643
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
644
+ prompt_ids=prompt_ids,
645
+ generation_config=generation_config,
646
+ config=self.config,
647
+ device=init_tokens.device,
648
+ suppress_tokens=suppress_tokens,
649
+ kwargs=kwargs,
650
+ )
651
+
652
+ # 6.4 set max new tokens or max length
653
+ self._set_max_new_tokens_and_length(
654
+ config=self.config,
655
+ decoder_input_ids=decoder_input_ids,
656
+ generation_config=generation_config,
657
+ )
658
+
659
+ # 6.5 Set current `begin_index` for all logit processors
660
+ if logits_processor is not None:
661
+ for proc in logits_processor:
662
+ if hasattr(proc, "set_begin_index"):
663
+ proc.set_begin_index(decoder_input_ids.shape[-1])
664
+
665
+ # 6.6 Run generate with fallback
666
+ (
667
+ seek_sequences,
668
+ seek_outputs,
669
+ should_skip,
670
+ do_condition_on_prev_tokens,
671
+ model_output_type,
672
+ ) = self.generate_with_fallback(
673
+ segment_input=segment_input,
674
+ decoder_input_ids=decoder_input_ids,
675
+ cur_bsz=cur_bsz,
676
+ batch_idx_map=batch_idx_map,
677
+ seek=seek,
678
+ num_segment_frames=num_segment_frames,
679
+ max_frames=max_frames,
680
+ temperatures=temperatures,
681
+ generation_config=generation_config,
682
+ logits_processor=logits_processor,
683
+ stopping_criteria=stopping_criteria,
684
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
685
+ synced_gpus=synced_gpus,
686
+ return_token_timestamps=return_token_timestamps,
687
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
688
+ is_shortform=is_shortform,
689
+ batch_size=batch_size,
690
+ kwargs=kwargs,
691
+ )
692
+
693
+ # 6.7 In every generated sequence, split by timestamp tokens and extract segments
694
+ for i, seek_sequence in enumerate(seek_sequences):
695
+ prev_i = batch_idx_map[i]
696
+
697
+ if should_skip[i]:
698
+ seek[prev_i] += seek_num_frames[prev_i]
699
+ continue
700
+
701
+ segments, segment_offset = self._retrieve_segment(
702
+ seek_sequence=seek_sequence,
703
+ seek_outputs=seek_outputs,
704
+ time_offset=time_offset,
705
+ timestamp_begin=timestamp_begin,
706
+ seek_num_frames=seek_num_frames,
707
+ time_precision=time_precision,
708
+ input_stride=input_stride,
709
+ prev_idx=prev_i,
710
+ idx=i,
711
+ return_token_timestamps=return_token_timestamps,
712
+ )
713
+
714
+ current_segments[prev_i] += segments
715
+
716
+ if is_shortform:
717
+ seek[prev_i] += max_frames[i]
718
+ else:
719
+ seek[prev_i] += segment_offset
720
+
721
+ # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
722
+ # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
723
+ final_segments = (
724
+ [x[1:] for x in current_segments]
725
+ if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
726
+ else current_segments
727
+ )
728
+
729
+ sequences = _pad_to_max_length(
730
+ final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
731
+ )
732
+
733
+ # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
734
+ if return_segments:
735
+ return {"sequences": sequences, "segments": final_segments}
736
+
737
+ if is_shortform:
738
+ # add eos token:
739
+ if generation_config.max_new_tokens is None and generation_config.max_length is None:
740
+ eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
741
+ sequences = torch.cat([sequences, eos_tokens], dim=-1)
742
+
743
+ if return_token_timestamps:
744
+ outputs = {}
745
+ outputs["sequences"] = sequences
746
+ outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
747
+ else:
748
+ outputs = sequences
749
+
750
+ if return_dict_in_generate and generation_config.return_dict_in_generate:
751
+ dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
752
+
753
+ if num_return_sequences > 1:
754
+ if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
755
+ dict_outputs.encoder_attentions = tuple(
756
+ dict_outputs.encoder_attentions[i][::num_return_sequences]
757
+ for i in range(len(dict_outputs.encoder_attentions))
758
+ )
759
+ if (
760
+ hasattr(dict_outputs, "encoder_hidden_states")
761
+ and dict_outputs.encoder_hidden_states is not None
762
+ ):
763
+ dict_outputs.encoder_hidden_states = tuple(
764
+ dict_outputs.encoder_hidden_states[i][::num_return_sequences]
765
+ for i in range(len(dict_outputs.encoder_hidden_states))
766
+ )
767
+ if return_token_timestamps:
768
+ dict_outputs["token_timestamps"] = outputs["token_timestamps"]
769
+ return dict_outputs
770
+
771
+ return outputs
772
+
773
+ return sequences
774
+
775
+ def generate_with_fallback(
776
+ self,
777
+ segment_input,
778
+ decoder_input_ids,
779
+ cur_bsz,
780
+ batch_idx_map,
781
+ seek,
782
+ num_segment_frames,
783
+ max_frames,
784
+ temperatures,
785
+ generation_config,
786
+ logits_processor,
787
+ stopping_criteria,
788
+ prefix_allowed_tokens_fn,
789
+ synced_gpus,
790
+ return_token_timestamps,
791
+ do_condition_on_prev_tokens,
792
+ is_shortform,
793
+ batch_size,
794
+ kwargs,
795
+ ):
796
+ kwargs = copy.copy(kwargs)
797
+
798
+ # 6.6 Batch generate current chunk
799
+ seek_sequence_list = [None for _ in range(cur_bsz)]
800
+ seek_outputs_list = [None for _ in range(cur_bsz)]
801
+ needs_fallback = [False for _ in range(cur_bsz)]
802
+ should_skip = [False for _ in range(cur_bsz)]
803
+ fallback_index_map = list(range(cur_bsz))
804
+ if generation_config.no_speech_threshold is not None:
805
+ self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs)
806
+
807
+ for fallback_idx, temperature in enumerate(temperatures):
808
+ generation_config.do_sample = temperature is not None and temperature > 0.0
809
+ generation_config.temperature = temperature if generation_config.do_sample else 1.0
810
+ if generation_config.do_sample:
811
+ generation_config.num_beams = 1
812
+
813
+ generate_kwargs = copy.copy(kwargs)
814
+ for key in ["do_sample", "temperature", "num_beams"]:
815
+ if key in generate_kwargs:
816
+ del generate_kwargs[key]
817
+
818
+ cur_bsz = decoder_input_ids.shape[0]
819
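+ # note: a static KV cache keeps a fixed batch dimension, so a shrunken batch is padded back up to the original batch size before generating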
+ if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
820
+ segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
821
+ decoder_input_ids = F.pad(
822
+ decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
823
+ )
824
+ if generate_kwargs.get("decoder_attention_mask") is not None:
825
+ generate_kwargs["decoder_attention_mask"] = F.pad(
826
+ generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
827
+ )
828
+ if generate_kwargs.get("encoder_outputs") is not None:
829
+ generate_kwargs["encoder_outputs"] = F.pad(
830
+ generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
831
+ )
832
+
833
+ seek_outputs = super().generate(
834
+ segment_input,
835
+ generation_config=generation_config,
836
+ logits_processor=logits_processor,
837
+ stopping_criteria=stopping_criteria,
838
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
839
+ synced_gpus=synced_gpus,
840
+ decoder_input_ids=decoder_input_ids,
841
+ **generate_kwargs,
842
+ )
843
+
844
+ model_output_type = type(seek_outputs)
845
+
846
+ # post-process sequence tokens and outputs to be in list form
847
+ seek_sequences, seek_outputs = self._postprocess_outputs(
848
+ seek_outputs=seek_outputs,
849
+ decoder_input_ids=decoder_input_ids,
850
+ return_token_timestamps=return_token_timestamps,
851
+ generation_config=generation_config,
852
+ is_shortform=is_shortform,
853
+ )
854
+
855
+ if cur_bsz < batch_size:
856
+ seek_sequences = seek_sequences[:cur_bsz]
857
+ seek_outputs = seek_outputs[:cur_bsz]
858
+
859
+ # 6.7 Extract cut sequences from every sequence and check if fallback should be applied
860
+ # Loop over each decoded audio individually as each decoding can be of a different length
861
+ new_fallback_index_map = []
862
+ new_segment_input = []
863
+ new_decoder_input_ids = []
864
+ new_decoder_attention_mask = []
865
+
866
+ for i, seek_sequence in enumerate(seek_sequences):
867
+ # make sure we cut a predicted EOS token if we are not finished with the generation yet
868
+ prev_i = batch_idx_map[fallback_index_map[i]]
869
+ is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
870
+
871
+ # remove eos token id
872
+ if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
873
+ seek_sequence = seek_sequence[:-1]
874
+ if return_token_timestamps and not is_shortform:
875
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
876
+
877
+ # remove all padding tokens
878
+ if seek_sequence[-1] == generation_config.pad_token_id:
879
+ num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
880
+ seek_sequence = seek_sequence[:-num_paddings]
881
+ if return_token_timestamps and not is_shortform:
882
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
883
+
884
+ # check which sequences in batch need fallback & which should be skipped
885
+ needs_fallback[i], should_skip[i] = self._need_fallback(
886
+ seek_sequence,
887
+ seek_outputs,
888
+ i,
889
+ logits_processor,
890
+ generation_config,
891
+ self.config.vocab_size,
892
+ temperature,
893
+ )
894
+
895
+ seek_sequence_list[fallback_index_map[i]] = seek_sequence
896
+ seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
897
+ is_low_temperature = temperature is None or temperature < 0.5
898
+ do_condition_on_prev_tokens[fallback_index_map[i]] = (
899
+ generation_config.condition_on_prev_tokens and is_low_temperature
900
+ )
901
+
902
+ if needs_fallback[i]:
903
+ new_fallback_index_map.append(fallback_index_map[i])
904
+ new_segment_input.append(segment_input[i])
905
+ new_decoder_input_ids.append(decoder_input_ids[i])
906
+ if "decoder_attention_mask" in kwargs:
907
+ new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i])
908
+
909
+ fallback_index_map = new_fallback_index_map
910
+
911
+ # if no sequence needs to be run with temperature fallback, we're finished
912
+ if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
913
+ seek_sequences = seek_sequence_list
914
+ seek_outputs = seek_outputs_list
915
+ break
916
+
917
+ # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
918
+ decoder_input_ids = torch.stack(new_decoder_input_ids)
919
+ segment_input = torch.stack(new_segment_input)
920
+ if "decoder_attention_mask" in kwargs:
921
+ kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask)
922
+
923
+ return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
924
+
925
+ @staticmethod
926
+ def _prepare_segments(prompt_ids, batch_size, generation_config):
927
+ if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
928
+ prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
929
+ prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
930
+ current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
931
+ else:
932
+ current_segments = [[] for _ in range(batch_size)]
933
+
934
+ return current_segments
935
+
936
+ def _postprocess_outputs(
937
+ self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
938
+ ):
939
+ # remove all previously passed decoder input ids
940
+ start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
941
+
942
+ if isinstance(seek_outputs, torch.Tensor):
943
+ seek_outputs = seek_outputs[:, start_idx:]
944
+ return seek_outputs, seek_outputs
945
+
946
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
947
+ num_frames = getattr(generation_config, "num_frames", None)
948
+ seek_outputs["token_timestamps"] = self._extract_token_timestamps(
949
+ seek_outputs, generation_config.alignment_heads, num_frames=num_frames
950
+ )
951
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
952
+
953
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
954
+
955
+ def split_by_batch_index(values, key, batch_idx, is_shortform):
956
+ if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
957
+ return [v[batch_idx].cpu() for v in values]
958
+ if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
959
+ return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
960
+ elif key == "past_key_values":
961
+ if not is_shortform:
962
+ # we don't save `past_key_values` as this is too costly for longform
963
+ return None
964
+ elif isinstance(values, EncoderDecoderCache):
965
+ all_past_key_values = []
966
+ for layer_idx in range(self.config.decoder_layers):
967
+ layer_past_key_values = []
968
+ for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
969
+ for v in [cache_cls.key_cache, cache_cls.value_cache]:
970
+ layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
971
+ all_past_key_values.append(tuple(layer_past_key_values))
972
+ return tuple(all_past_key_values)
973
+ else:
974
+ all_past_key_values = []
975
+ for v in range(len(values)):
976
+ layer_past_key_values = []
977
+ for w in values[v]:
978
+ layer_past_key_values.append(w[batch_idx][None].cpu())
979
+ all_past_key_values.append(tuple(layer_past_key_values))
980
+ return tuple(all_past_key_values)
981
+
982
+ return values[batch_idx].cpu()
983
+
984
+ sequence_tokens = seek_outputs["sequences"]
985
+ seek_outputs = [
986
+ {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
987
+ for i in range(sequence_tokens.shape[0])
988
+ ]
989
+
990
+ return sequence_tokens, seek_outputs
991
+
992
+ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
993
+ # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
994
+ outputs = {}
995
+ for key in seek_outputs[0].keys():
996
+ if key == "sequences":
997
+ outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
998
+ if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
999
+ outputs[key] = tuple(
1000
+ torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
1001
+ )
1002
+ if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
1003
+ outputs[key] = tuple(
1004
+ tuple(
1005
+ torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
1006
+ for j in range(len(seek_outputs[0][key][0]))
1007
+ )
1008
+ for i in range(len(seek_outputs[0][key]))
1009
+ )
1010
+ if key == "past_key_values":
1011
+ past_key_value_type = kwargs.get("past_key_values")
1012
+ if seek_outputs[0][key] is not None:
1013
+ outputs[key] = tuple(
1014
+ tuple(
1015
+ torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
1016
+ for j in range(len(seek_outputs[0][key][0]))
1017
+ )
1018
+ for i in range(len(seek_outputs[0][key]))
1019
+ )
1020
+ if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache):
1021
+ outputs[key] = past_key_value_type.from_legacy_cache(outputs[key])
1022
+ else:
1023
+ outputs[key] = None
1024
+
1025
+ return model_output_type(**outputs)
1026
+
1027
+ def _need_fallback(
1028
+ self,
1029
+ seek_sequence,
1030
+ seek_outputs,
1031
+ index,
1032
+ logits_processor,
1033
+ generation_config,
1034
+ vocab_size,
1035
+ temperature,
1036
+ ):
1037
+ needs_fallback = False
1038
+ should_skip = False
1039
+ if generation_config.compression_ratio_threshold is not None:
1040
+ compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size)
1041
+
1042
+ if compression_ratio > generation_config.compression_ratio_threshold:
1043
+ needs_fallback = True
1044
+
1045
+ if generation_config.logprob_threshold is not None:
1046
+ if hasattr(seek_outputs[0], "sequences_scores"):
1047
+ logprobs = [s["sequences_scores"] for s in seek_outputs][index]
1048
+ else:
1049
+ scores = seek_outputs[index]["scores"]
1050
+ logprobs = self._retrieve_avg_logprobs(
1051
+ scores, seek_sequence, generation_config.eos_token_id, temperature
1052
+ )
1053
+
1054
+ if logprobs < generation_config.logprob_threshold:
1055
+ needs_fallback = True
1056
+
1057
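+ # note: a segment that looks like silence (no-speech probability above the threshold while the average logprob is below the logprob threshold) is skipped instead of retried with a higher temperature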
+ if generation_config.no_speech_threshold is not None:
1058
+ no_speech_prob = _get_attr_from_logit_processors(
1059
+ logits_processor, WhisperNoSpeechDetection, "no_speech_prob"
1060
+ )
1061
+
1062
+ if (
1063
+ logprobs < generation_config.logprob_threshold
1064
+ and no_speech_prob[index] > generation_config.no_speech_threshold
1065
+ ):
1066
+ needs_fallback = False
1067
+ should_skip = True
1068
+
1069
+ return needs_fallback, should_skip
1070
+
1071
+ def _expand_variables_for_generation(
1072
+ self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config
1073
+ ):
1074
+ if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1:
1075
+ batch_idx_map = list(range(batch_size * generation_config.num_return_sequences))
1076
+ cur_bsz = len(batch_idx_map)
1077
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))]
1078
+ input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0)
1079
+ seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0)
1080
+ max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0)
1081
+ init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0)
1082
+ generation_config.num_return_sequences = 1
1083
+ else:
1084
+ cur_bsz = batch_size
1085
+ batch_idx_map = list(range(cur_bsz))
1086
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)]
1087
+
1088
+ return (
1089
+ batch_idx_map,
1090
+ cur_bsz,
1091
+ input_features,
1092
+ seek,
1093
+ max_frames,
1094
+ init_tokens,
1095
+ do_condition_on_prev_tokens,
1096
+ )
1097
+
1098
+ @staticmethod
1099
+ def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
1100
+ set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
1101
+ extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
1102
+ set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
1103
+
1104
+ @staticmethod
1105
+ def _retrieve_total_input_frames(input_features, input_stride, kwargs):
1106
+ if input_features is not None:
1107
+ return input_features.shape[0], input_features.shape[-1]
1108
+
1109
+ if "encoder_outputs" in kwargs:
1110
+ encoder_outputs_shape = (
1111
+ kwargs["encoder_outputs"][0].shape
1112
+ if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
1113
+ else kwargs["encoder_outputs"].shape
1114
+ )
1115
+ return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
1116
+
1117
+ raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
1118
+
1119
+ @staticmethod
1120
+ def _maybe_warn_unused_inputs(
1121
+ condition_on_prev_tokens,
1122
+ temperature,
1123
+ compression_ratio_threshold,
1124
+ logprob_threshold,
1125
+ no_speech_threshold,
1126
+ total_input_frames,
1127
+ ):
1128
+ warning_prefix = (
1129
+ f"Audio input consists of only {total_input_frames}. "
1130
+ "Short-form transcription is activated."
1131
+ "{}, but will be ignored."
1132
+ )
1133
+ if condition_on_prev_tokens is not None:
1134
+ logger.warning(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}"))
1135
+
1136
+ if compression_ratio_threshold is not None:
1137
+ logger.warning(
1138
+ warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}")
1139
+ )
1140
+
1141
+ if logprob_threshold is not None:
1142
+ logger.warning(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}"))
1143
+
1144
+ if no_speech_threshold is not None:
1145
+ logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
1146
+
1147
+ # when passing temperature as a list it cannot just be ignored => throw error in this case
1148
+ if isinstance(temperature, (list, tuple)):
1149
+ raise ValueError(
1150
+ f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
1151
+ f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
1152
+ )
1153
+
1154
+ @staticmethod
1155
+ def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
1156
+ if return_dict_in_generate is None:
1157
+ return_dict_in_generate = generation_config.return_dict_in_generate
1158
+ else:
1159
+ generation_config.return_dict_in_generate = return_dict_in_generate
1160
+
1161
+ generation_config.return_token_timestamps = return_token_timestamps
1162
+ if return_token_timestamps:
1163
+ generation_config.return_dict_in_generate = True
1164
+ generation_config.output_attentions = True
1165
+ generation_config.output_scores = True
1166
+
1167
+ if logprob_threshold is not None:
1168
+ generation_config.return_dict_in_generate = True
1169
+ generation_config.output_scores = True
1170
+
1171
+ return return_dict_in_generate
1172
+
1173
+ def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
1174
+ if return_timestamps is None and hasattr(generation_config, "return_timestamps"):
1175
+ return_timestamps = generation_config.return_timestamps
1176
+
1177
+ if not is_shortform:
1178
+ if return_timestamps is False:
1179
+ raise ValueError(
1180
+ "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
1181
+ "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
1182
+ )
1183
+
1184
+ logger.info("Setting `return_timestamps=True` for long-form generation.")
1185
+ return_timestamps = True
1186
+
1187
+ if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
1188
+ raise ValueError(
1189
+ "You are trying to return timestamps, but the generation config is not properly set. "
1190
+ "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
1191
+ "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
1192
+ )
1193
+
1194
+ generation_config.return_timestamps = return_timestamps
1195
+
1196
+ if hasattr(generation_config, "no_timestamps_token_id"):
1197
+ timestamp_begin = generation_config.no_timestamps_token_id + 1
1198
+ else:
1199
+ # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
1200
+ # We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
1201
+ timestamp_begin = self.config.vocab_size + 1
1202
+
1203
+ return timestamp_begin
1204
+
1205
+ @staticmethod
1206
+ def _set_language_and_task(language, task, is_multilingual, generation_config):
1207
+ if is_multilingual is not None:
1208
+ if not hasattr(generation_config, "is_multilingual"):
1209
+ raise ValueError(
1210
+ "The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
1211
+ "to `generate`. Please update the generation config as per the instructions "
1212
+ "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1213
+ )
1214
+ generation_config.is_multilingual = is_multilingual
1215
+
1216
+ if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
1217
+ if task is not None or language is not None:
1218
+ raise ValueError(
1219
+ "Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
1220
+ "multilingual, pass `is_multilingual=True` to generate, or update the generation config."
1221
+ )
1222
+
1223
+ if language is not None:
1224
+ if not hasattr(generation_config, "lang_to_id"):
1225
+ raise ValueError(
1226
+ "The generation config is outdated and is thus not compatible with the `language` argument "
1227
+ "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
1228
+ "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1229
+ )
1230
+ generation_config.language = language
1231
+
1232
+ if task is not None:
1233
+ if not hasattr(generation_config, "task_to_id"):
1234
+ raise ValueError(
1235
+ "The generation config is outdated and is thus not compatible with the `task` argument "
1236
+ "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
1237
+ "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1238
+ )
1239
+ generation_config.task = task
1240
+
1241
+ def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1242
+ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
1243
+ """short function to replace num with a itr in lst"""
1244
+ found = any(i in lst for i in itr)
1245
+ if found:
1246
+ lst = [num if i in itr else i for i in lst]
1247
+ else:
1248
+ lst.append(num)
1249
+ return lst
1250
+
1251
+ def language_to_id(language: str) -> int:
1252
+ language = language.lower()
1253
+ if language in generation_config.lang_to_id.keys():
1254
+ language_token = language
1255
+ elif language in TO_LANGUAGE_CODE.keys():
1256
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
1257
+ elif language in TO_LANGUAGE_CODE.values():
1258
+ language_token = f"<|{language}|>"
1259
+ else:
1260
+ is_language_code = len(language) == 2
1261
+ raise ValueError(
1262
+ f"Unsupported language: {language}. Language should be one of:"
1263
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1264
+ )
1265
+ if language_token not in generation_config.lang_to_id:
1266
+ raise ValueError(
1267
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
1268
+ "(You should just add it to the generation config)"
1269
+ )
1270
+
1271
+ return generation_config.lang_to_id[language_token]
1272
+
1273
+ task = getattr(generation_config, "task", None)
1274
+ language = getattr(generation_config, "language", None)
1275
+
1276
+ forced_decoder_ids = generation_config.forced_decoder_ids
1277
+ if forced_decoder_ids is not None:
1278
+ if language is None and task is None and forced_decoder_ids[0][1] is None:
1279
+ logger.warning_once(
1280
+ "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
1281
+ "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
1282
+ )
1283
+ elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
1284
+ forced_decoder_ids = config.forced_decoder_ids
1285
+
1286
+ if forced_decoder_ids is not None and task is not None:
1287
+ logger.warning_once(
1288
+ f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
1289
+ )
1290
+ forced_decoder_ids = None
1291
+ elif forced_decoder_ids is not None and language is not None:
1292
+ logger.warning_once(
1293
+ f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
1294
+ )
1295
+ forced_decoder_ids = None
1296
+
1297
+ init_tokens = [generation_config.decoder_start_token_id]
1298
+ if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
1299
+ i = 1
1300
+ while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
1301
+ init_tokens += [forced_decoder_ids[0][1]]
1302
+ forced_decoder_ids = forced_decoder_ids[1:]
1303
+ i += 1
1304
+
1305
+ if len(forced_decoder_ids) > 0:
1306
+ raise ValueError(
1307
+ f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
1308
+ )
1309
+
1310
+ # from v4.39 the forced decoder ids are always None in favour of decoder input ids
1311
+ generation_config.forced_decoder_ids = None
1312
+
1313
+ is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
1314
+
1315
+ # Make sure language is a list of strings of the correct length
1316
+ if isinstance(language, (list, tuple)):
1317
+ if any(l is None for l in language):
1318
+ raise TypeError(
1319
+ "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
1320
+ )
1321
+ if len(language) != batch_size:
1322
+ raise ValueError(
1323
+ "When passing a list of languages, the length of the list must match the batch size. "
1324
+ f"Expected length of {batch_size}, but got {len(language)} languages."
1325
+ )
1326
+ languages = language
1327
+ elif language is None:
1328
+ # Language will be detected for each item in batch
1329
+ languages = [None] * batch_size
1330
+ else:
1331
+ languages = [language] # Use a length-1 list now, broadcast later
1332
+
1333
+ # Separate init_tokens for each language
1334
+ init_tokens = [copy.copy(init_tokens) for _ in languages]
1335
+
1336
+ # Update init_tokens with languages
1337
+ lang_ids = None
1338
+ if language is not None:
1339
+ lang_ids = [language_to_id(l) for l in languages]
1340
+ elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
1341
+ # language is not defined or intentionally set to `None` to trigger language detection
1342
+ lang_ids = self.detect_language(
1343
+ input_features=input_features,
1344
+ encoder_outputs=kwargs.get("encoder_outputs", None),
1345
+ attention_mask=kwargs.get("attention_mask", None),
1346
+ generation_config=generation_config,
1347
+ num_segment_frames=num_segment_frames,
1348
+ ).tolist()
1349
+ if lang_ids is not None:
1350
+ # append or replace lang_ids to init_tokens
1351
+ for i in range(len(init_tokens)):
1352
+ if len(init_tokens[i]) > 1:
1353
+ init_tokens[i][1] = lang_ids[i]
1354
+ else:
1355
+ init_tokens[i].append(lang_ids[i])
1356
+ del languages
1357
+
1358
+ # Update init_tokens with task
1359
+ for i in range(len(init_tokens)):
1360
+ if task is not None:
1361
+ if task in TASK_IDS:
1362
+ init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1363
+ task_id = generation_config.task_to_id[generation_config.task]
1364
+
1365
+ # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
1366
+ replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
1367
+ else:
1368
+ raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
1369
+ elif language is not None and hasattr(generation_config, "task_to_id"):
1370
+ # if language is defined, but no task id is in `init_tokens`, default to transcribe
1371
+ if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
1372
+ init_tokens[i].append(generation_config.task_to_id["transcribe"])
1373
+
1374
+ if (
1375
+ not generation_config.return_timestamps
1376
+ and hasattr(generation_config, "no_timestamps_token_id")
1377
+ and init_tokens[i][-1] != generation_config.no_timestamps_token_id
1378
+ ):
1379
+ init_tokens[i].append(generation_config.no_timestamps_token_id)
1380
+ elif (
1381
+ generation_config.return_timestamps and init_tokens[i][-1] == generation_config.no_timestamps_token_id
1382
+ ):
1383
+ logger.info(
1384
+ "<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
1385
+ )
1386
+ init_tokens[i] = init_tokens[i][:-1]
1387
+
1388
+ # let's make sure we don't pass `None` tokens as prompt tokens
1389
+ init_tokens[i] = [t for t in init_tokens[i] if t is not None]
1390
+
1391
+ return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
1392
+
1393
+ def detect_language(
1394
+ self,
1395
+ input_features: Optional[torch.FloatTensor] = None,
1396
+ attention_mask: Optional[torch.LongTensor] = None,
1397
+ encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
1398
+ generation_config: Optional[GenerationConfig] = None,
1399
+ num_segment_frames: int = 3000,
1400
+ ) -> torch.Tensor:
1401
+ """
1402
+ Detects language from log-mel input features or encoder_outputs
1403
+
1404
+ Parameters:
1405
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
1406
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
1407
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
1408
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
1409
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
1410
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
1411
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
1412
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
1413
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
1414
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1415
+ generation_config (`~generation.GenerationConfig`, *optional*):
1416
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1417
+ passed to generate matching the attributes of `generation_config` will override them. If
1418
+ `generation_config` is not provided, the default will be used, which has the following loading
1419
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1420
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1421
+ default values, whose documentation should be checked to parameterize generation.
1422
+ num_segment_frames (`int`, *optional*, defaults to 3000):
1423
+ The number of log-mel frames the model expects
1424
+
1425
+ Return:
1426
+ A `torch.LongTensor` representing the detected language ids.
1427
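+
+ Example (a minimal sketch; assumes a multilingual checkpoint and a raw 16 kHz waveform `audio_array`):
+
+ ```python
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+
+ >>> input_features = processor(audio_array, sampling_rate=16_000, return_tensors="pt").input_features
+ >>> lang_token_ids = model.detect_language(input_features=input_features)
+ >>> # map the predicted ids back to language tokens such as "<|en|>"
+ >>> lang_tokens = processor.tokenizer.convert_ids_to_tokens(lang_token_ids.tolist())
+ ```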
+ """
1428
+ if input_features is None and encoder_outputs is None:
1429
+ raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
1430
+ elif input_features is not None and encoder_outputs is not None:
1431
+ raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
1432
+ elif input_features is not None:
1433
+ inputs = {"input_features": input_features[:, :, :num_segment_frames]}
1434
+ batch_size = input_features.shape[0]
1435
+ elif encoder_outputs is not None:
1436
+ inputs = {"encoder_outputs": encoder_outputs}
1437
+ batch_size = (
1438
+ encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
1439
+ )
1440
+ if attention_mask is not None:
1441
+ inputs["attention_mask"] = attention_mask
1442
+
1443
+ generation_config = generation_config or self.generation_config
1444
+ decoder_input_ids = (
1445
+ torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
1446
+ * generation_config.decoder_start_token_id
1447
+ )
1448
+
1449
+ with torch.no_grad():
1450
+ logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]
1451
+
1452
+ non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
1453
+ non_lang_mask[list(generation_config.lang_to_id.values())] = False
1454
+
1455
+ logits[:, non_lang_mask] = -np.inf
1456
+
1457
+ lang_ids = logits.argmax(-1)
1458
+
1459
+ return lang_ids
1460
+
1461
+ @staticmethod
1462
+ def _check_decoder_input_ids(kwargs):
1463
+ decoder_input_ids = kwargs.get("decoder_input_ids", None)
1464
+ assistant_model = kwargs.get("assistant_model", None)
1465
+ if decoder_input_ids is not None and assistant_model is not None:
1466
+ raise ValueError(
1467
+ "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
1468
+ )
1469
+
1470
+ @staticmethod
1471
+ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
1472
+ if return_token_timestamps:
1473
+ if getattr(generation_config, "task", None) == "translate":
1474
+ logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
1475
+ if not hasattr(generation_config, "alignment_heads"):
1476
+ raise ValueError(
1477
+ "Model generation config has no `alignment_heads`, token-level timestamps not available. "
1478
+ "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
1479
+ )
1480
+ generation_config.num_frames = kwargs.pop("num_frames", None)
1481
+
1482
+ @staticmethod
1483
+ def _set_thresholds_and_condition(
1484
+ generation_config,
1485
+ logprob_threshold,
1486
+ compression_ratio_threshold,
1487
+ no_speech_threshold,
1488
+ condition_on_prev_tokens,
1489
+ ):
1490
+ generation_config.logprob_threshold = (
1491
+ logprob_threshold
1492
+ if logprob_threshold is not None
1493
+ else getattr(generation_config, "logprob_threshold", None)
1494
+ )
1495
+ generation_config.compression_ratio_threshold = (
1496
+ compression_ratio_threshold
1497
+ if compression_ratio_threshold is not None
1498
+ else getattr(generation_config, "compression_ratio_threshold", None)
1499
+ )
1500
+ generation_config.no_speech_threshold = (
1501
+ no_speech_threshold
1502
+ if no_speech_threshold is not None
1503
+ else getattr(generation_config, "no_speech_threshold", None)
1504
+ )
1505
+ generation_config.condition_on_prev_tokens = (
1506
+ condition_on_prev_tokens
1507
+ if condition_on_prev_tokens is not None
1508
+ else getattr(generation_config, "condition_on_prev_tokens", None)
1509
+ )
1510
+
1511
+ @staticmethod
1512
+ def _set_prompt_condition_type(generation_config, prompt_condition_type):
1513
+ allowed_cond_types = ["first-segment", "all-segments"]
1514
+
1515
+ # default to "first-segment"
1516
+ prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
1517
+
1518
+ if prompt_condition_type not in allowed_cond_types:
1519
+ raise ValueError(
1520
+ f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
1521
+ )
1522
+
1523
+ if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
1524
+ raise ValueError(
1525
+ "Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
1526
+ )
1527
+
1528
+ generation_config.prompt_condition_type = prompt_condition_type
1529
+
1530
+ @staticmethod
1531
+ def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
1532
+ condition_on_prev_tokens = (
1533
+ condition_on_prev_tokens
1534
+ if condition_on_prev_tokens is not None
1535
+ else getattr(generation_config, "condition_on_prev_tokens", False)
1536
+ )
1537
+ generation_config.condition_on_prev_tokens = condition_on_prev_tokens
1538
+
1539
+ @staticmethod
1540
+ def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform):
1541
+ if batch_size > 1 and not is_shortform and attention_mask is None:
1542
+ raise ValueError(
1543
+ "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
1544
+ )
1545
+ elif batch_size > 1 and not is_shortform:
1546
+ max_frames = attention_mask.sum(-1).cpu().to(torch.long)
1547
+ seek = torch.zeros((batch_size,), dtype=torch.long)
1548
+ else:
1549
+ max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
1550
+ seek = torch.zeros((batch_size,), dtype=torch.long)
1551
+
1552
+ return max_frames, seek
1553
+
1554
+ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
1555
+ if generation_config.return_timestamps is True:
1556
+ timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
1557
+ logits_processor = (
1558
+ [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
1559
+ )
1560
+
1561
+ if generation_config.suppress_tokens is not None:
1562
+ suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
1563
+ logits_processor = (
1564
+ [suppress_tokens_processor]
1565
+ if logits_processor is None
1566
+ else [suppress_tokens_processor] + logits_processor
1567
+ )
1568
+ generation_config.suppress_tokens = None
1569
+
1570
+ if generation_config.begin_suppress_tokens is not None:
1571
+ begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
1572
+ generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
1573
+ )
1574
+ logits_processor = (
1575
+ [begin_suppress_processor]
1576
+ if logits_processor is None
1577
+ else [begin_suppress_processor] + logits_processor
1578
+ )
1579
+ generation_config.begin_suppress_tokens = None
1580
+
1581
+ if generation_config.no_speech_threshold is not None:
1582
+ no_speech_detector = WhisperNoSpeechDetection(
1583
+ no_speech_token=generation_config.no_timestamps_token_id - 1,
1584
+ begin_index=begin_index,
1585
+ scores_is_logprobs=num_beams > 1,
1586
+ )
1587
+ logits_processor = (
1588
+ [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
1589
+ )
1590
+ no_speech_detector.set_model(self)
1591
+
1592
+ return logits_processor
1593
+
1594
+ @staticmethod
1595
+ def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map):
1596
+ prev_bsz = cur_bsz
1597
+ new_batch_idx_map = []
1598
+ for i in range(prev_bsz):
1599
+ prev_i = batch_idx_map[i]
1600
+ if seek[prev_i] >= max_frames[prev_i]:
1601
+ cut_index = i + (cur_bsz - prev_bsz)
1602
+ cur_bsz -= 1
1603
+ input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
1604
+ else:
1605
+ # keep indices of audios that are not yet fully decoded
1606
+ new_batch_idx_map.append(prev_i)
1607
+
1608
+ return input_features, cur_bsz, new_batch_idx_map
1609
+
1610
+ @staticmethod
1611
+ def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
1612
+ if input_features is None:
1613
+ return None
1614
+
1615
+ segment_input = []
1616
+ for i in range(cur_bsz):
1617
+ prev_i = batch_idx_map[i]
1618
+ segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]]
1619
+
1620
+ if segment_input_slice.shape[-1] < num_segment_frames:
1621
+ # pad to num_segment_frames (3000 mel frames, i.e. 30 s) if necessary
1622
+ segment_input_slice = F.pad(
1623
+ segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
1624
+ )
1625
+
1626
+ segment_input.append(segment_input_slice)
1627
+
1628
+ segment_input = torch.cat(segment_input, dim=0)
1629
+
1630
+ return segment_input
1631
+
1632
+ @staticmethod
1633
+ def _prepare_decoder_input_ids(
1634
+ cur_bsz,
1635
+ init_tokens,
1636
+ current_segments,
1637
+ batch_idx_map,
1638
+ do_condition_on_prev_tokens,
1639
+ prompt_ids,
1640
+ generation_config,
1641
+ config,
1642
+ device,
1643
+ suppress_tokens,
1644
+ kwargs,
1645
+ ):
1646
+ if "decoder_input_ids" in kwargs:
1647
+ decoder_input_ids = kwargs.pop("decoder_input_ids")
1648
+
1649
+ return decoder_input_ids, kwargs
1650
+
1651
+ cut_off_length = config.max_target_positions // 2 - 1
1652
+
1653
+ decoder_input_ids = init_tokens[batch_idx_map]
1654
+
1655
+ prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
1656
+ if prev_start_of_text is None:
1657
+ prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None
1658
+
1659
+ if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
1660
+ # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
1661
+ active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
1662
+
1663
+ if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
1664
+ prev_ids = prompt_ids
1665
+ else:
1666
+ one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
1667
+ prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
1668
+
1669
+ padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
1670
+
1671
+ prev_tokens = _pad_to_max_length(
1672
+ active_segments,
1673
+ generation_config.pad_token_id,
1674
+ device=device,
1675
+ padding_side="left",
1676
+ padding=padding,
1677
+ bos_token_tensor=prev_ids,
1678
+ cut_off_length=cut_off_length,
1679
+ )
1680
+ decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
1681
+
1682
+ kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
1683
+ elif prompt_ids is not None:
1684
+ prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
1685
+ decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
1686
+ # make sure `"decoder_attention_mask"` is not passed to forward
1687
+ kwargs.pop("decoder_attention_mask", None)
1688
+ else:
1689
+ # make sure `"decoder_attention_mask"` is not passed to forward
1690
+ kwargs.pop("decoder_attention_mask", None)
1691
+
1692
+ return decoder_input_ids, kwargs
1693
+
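To visualize the conditioned decoder prompt produced above, here is a rough sketch; the token IDs are invented for illustration, and 448 is used as a stand-in for `config.max_target_positions`.

# hypothetical token IDs, purely illustrative
prev_sot, sot = 50361, 50258
prev_segment_tokens = [105, 106, 107]     # tokens decoded for the previous window
init_tokens = [sot, 50259, 50359]         # start-of-transcript + language/task tokens

cut_off_length = 448 // 2 - 1             # config.max_target_positions // 2 - 1
decoder_input_ids = [prev_sot] + prev_segment_tokens[-cut_off_length:] + init_tokens
# everything left of `sot` is context only; the cut-off keeps the prompt short enough
# that the newly generated segment still fits within max_target_positions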
1694
+ def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config):
1695
+ max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
1696
+ if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
1697
+ raise ValueError(
1698
+ f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
1699
+ f"is {max_new_tokens}. Thus, the combined length of "
1700
+ f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
1701
+ f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
1702
+ "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
1703
+ f"so that their combined length is less than {self.config.max_target_positions}."
1704
+ )
1705
+
1706
+ num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
1707
+
1708
+ # Make sure we don't get larger than `max_length`
1709
+ if generation_config.max_length is not None and generation_config.max_new_tokens is None:
1710
+ max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
1711
+ logger.info(
1712
+ f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
1713
+ )
1714
+ elif (
1715
+ generation_config.max_new_tokens is not None
1716
+ and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
1717
+ ):
1718
+ max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
1719
+ generation_config.max_new_tokens = max_new_tokens
1720
+
1721
+ @staticmethod
1722
+ def _retrieve_compression_ratio(tokens, vocab_size):
1723
+ """Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes"""
1724
+ length = int(math.log2(vocab_size) / 8) + 1
1725
+ token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()])
1726
+ compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes))
1727
+
1728
+ return compression_ratio
1729
+
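A rough standalone sketch of how this compression-ratio heuristic flags degenerate, repetitive output (the vocabulary size below is arbitrary; the real value comes from the model config):

import math, random, zlib

def compression_ratio(token_ids, vocab_size=51866):
    length = int(math.log2(vocab_size) / 8) + 1                      # bytes per token id
    raw = b"".join(t.to_bytes(length, "little") for t in token_ids)
    return len(raw) / len(zlib.compress(raw))

random.seed(0)
varied = [random.randrange(51866) for _ in range(100)]
print(compression_ratio(varied))          # varied tokens barely compress -> ratio stays around 1
print(compression_ratio([17, 42] * 100))  # looping output compresses well -> ratio far above the usual 2.4 threshold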
1730
+ @staticmethod
1731
+ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
1732
+ rescale_temperature = temperature if temperature > 0.0 else 1
1733
+ scores = torch.stack(scores).to(tokens.device)
1734
+
1735
+ if scores.shape[0] > tokens.shape[0]:
1736
+ scores = scores[: tokens.shape[0]]
1737
+ else:
1738
+ tokens = tokens[-scores.shape[0] :]
1739
+
1740
+ logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
1741
+
1742
+ # retrieve logprob of selected tokens and sum
1743
+ sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
1744
+ length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
1745
+
1746
+ avg_logprobs = sum_logprobs / (length + 1)
1747
+ return avg_logprobs
1748
+
1749
+ @staticmethod
1750
+ def _retrieve_segment(
1751
+ seek_sequence,
1752
+ seek_outputs,
1753
+ time_offset,
1754
+ timestamp_begin,
1755
+ seek_num_frames,
1756
+ time_precision,
1757
+ input_stride,
1758
+ prev_idx,
1759
+ idx,
1760
+ return_token_timestamps,
1761
+ ):
1762
+ # find Whisper's predicted "end of segment" positions
1763
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
1764
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
1765
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
1766
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
1767
+ timestamp_segment_indices.add_(1)
1768
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
1769
+
1770
+ # If Whisper predicted an "end of segment" via a timestamp token, go over each
1771
+ # "end of segment" prediction and slice the decoding into segments accordingly
1772
+ if len(timestamp_segment_indices) > 0:
1773
+ # if the output contains two consecutive timestamp tokens
1774
+ slices = timestamp_segment_indices.tolist()
1775
+ segments = []
1776
+ if single_timestamp_ending:
1777
+ slices.append(len(seek_sequence))
1778
+
1779
+ last_slice = 0
1780
+ # Add each segment to list of all segments
1781
+ for current_slice in slices:
1782
+ sliced_tokens = seek_sequence[last_slice:current_slice]
1783
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
1784
+ end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
1785
+ segments.append(
1786
+ {
1787
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
1788
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
1789
+ "tokens": sliced_tokens,
1790
+ "result": seek_outputs[idx],
1791
+ }
1792
+ )
1793
+ if return_token_timestamps:
1794
+ segments[-1]["token_timestamps"] = (
1795
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
1796
+ )
1797
+ last_slice = current_slice
1798
+
1799
+ if single_timestamp_ending:
1800
+ # single timestamp at the end means no speech after the last timestamp.
1801
+ segment_offset = seek_num_frames[prev_idx]
1802
+ else:
1803
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
1804
+ # here we throw away all predictions after the last predicted "end of segment"
1805
+ # since we are cutting right in the middle of the audio
1806
+ last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
1807
+ segment_offset = last_timestamp_pos * input_stride
1808
+ else:
1809
+ # If whisper does not predict any "end of segment" token, then
1810
+ # the whole decoding is considered a segment and we add it to the list of segments
1811
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
1812
+ last_timestamp_pos = seek_num_frames[prev_idx]
1813
+ if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
1814
+ # no consecutive timestamps but it has a timestamp; use the last one.
1815
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
1816
+ segments = [
1817
+ {
1818
+ "start": time_offset[prev_idx],
1819
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
1820
+ "tokens": seek_sequence,
1821
+ "result": seek_outputs[idx],
1822
+ }
1823
+ ]
1824
+ if return_token_timestamps:
1825
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
1826
+ segment_offset = seek_num_frames[prev_idx]
1827
+
1828
+ return segments, segment_offset
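To make the timestamp arithmetic in `_retrieve_segment` concrete, a small standalone sketch; the token IDs, `timestamp_begin`, and window offset below are illustrative, while 0.02 s per timestamp step follows Whisper's convention.

timestamp_begin, time_precision, window_offset = 50364, 0.02, 30.0   # illustrative values

# decoded window: <|0.00|> hello world <|1.50|> <|1.50|> again <|3.00|>
seek_sequence = [50364, 101, 102, 50439, 50439, 103, 50514]

# two consecutive timestamp tokens (indices 3 and 4) mark an "end of segment"
start = window_offset + (seek_sequence[0] - timestamp_begin) * time_precision   # 30.0 s
end = window_offset + (seek_sequence[3] - timestamp_begin) * time_precision     # 31.5 s
print(f"segment 1: {start:.2f}s -> {end:.2f}s")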
models/glm_speech_tokenizer/modeling_whisper.py ADDED
The diff for this file is too large to render. See raw diff
 
models/glm_speech_tokenizer/speech_token_extractor.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ import sys
3
+ sys.path.append("../../..")
4
+ import io
5
+ import glob
6
+ import math
7
+ import tarfile
8
+ import torch
9
+ import torchaudio
10
+ import safetensors
11
+ from .configuration_whisper import WhisperVQConfig
12
+ from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
13
+ from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
14
+ import asyncio
15
+ from .batch_processor import AsyncBatchEngine # adjust this import path to your layout
16
+ from typing import List, Union, Tuple, Literal, Optional
17
+
18
+
19
+ class SpeechTokenExtractor:
20
+ def __init__(
21
+ self,
22
+ model: WhisperVQEncoder,
23
+ feature_extractor: WhisperFeatureExtractor,
24
+ device: Literal["cpu", "cuda", "mps"] | str = "cuda",
25
+ batch_size: int = 32,
26
+ wait_timeout: float = 0.01,
27
+ ):
28
+ self.model = model.eval().to(device)
29
+ self.feature_extractor = feature_extractor
30
+ self.device = device
31
+ self.wait_timeout = wait_timeout
32
+ self.dtype = next(model.parameters()).dtype
33
+
34
+ # frame/sample stride (used for pad alignment & attention-mask downsampling)
35
+ self.pooling_kernel_size = getattr(model.config, "pooling_kernel_size", 1)
36
+ self.frame_stride = (
37
+ model.conv1.stride[0] *
38
+ model.conv2.stride[0] *
39
+ self.pooling_kernel_size
40
+ )
41
+ self.sample_stride = self.frame_stride * feature_extractor.hop_length
42
+
43
+ # resampler cache (kept on the device)
44
+ self._resamplers: dict[int, torchaudio.transforms.Resample] = {}
45
+
46
+ self._batch_processor = AsyncBatchEngine(
47
+ processing_function=self._batch_extract_async,
48
+ batch_size=batch_size,
49
+ wait_timeout=wait_timeout,
50
+ )
51
+
52
+ # -------- I/O & resampling: keep everything on the device --------
53
+ def _load_audio(self, utt: Union[str, torch.Tensor]) -> torch.Tensor:
54
+ """读取单条音频 -> 1D float32 waveform(在 self.device 上,采样率16k)。"""
55
+ # print(f"audio type is {type(utt)}")
56
+ if isinstance(utt, torch.Tensor):
57
+ # audio, sr = utt
58
+ audio = utt.to(self.device, non_blocking=True)
59
+ else:
60
+ audio, sr = torchaudio.load(utt) # CPU
61
+ if audio.ndim > 1 and audio.size(0) > 1: # downmix to mono
62
+ audio = audio.mean(dim=0, keepdim=True)
63
+ audio = audio.squeeze(0).to(torch.float32).to(self.device, non_blocking=True)
64
+
65
+ return audio # [T] on device
66
+
67
+ # -------- run the feature_extractor on the GPU --------
68
+ def _extract_features_gpu(self, audios: List[torch.Tensor]) -> dict:
69
+ """
70
+ 1) Convert the inputs to CPU numpy float32 (required by the feature extractor)
71
+ 2) Call the feature extractor with device=self.device so the output tensors land directly on the GPU
72
+ 3) If the model is fp16, cast only input_features to half (leave the mask untouched)
73
+ """
74
+ # 1) CUDA/CPU Tensor -> CPU numpy
75
+ np_audios = [a.detach().cpu().numpy().astype("float32") for a in audios]
76
+
77
+
78
+ feats = self.feature_extractor(
79
+ np_audios,
80
+ sampling_rate=16000,
81
+ return_attention_mask=True,
82
+ return_tensors="pt",
83
+ device=self.device, # keep the extracted features on the target device
84
+ padding="longest",
85
+ pad_to_multiple_of=self.sample_stride,
86
+ )
87
+
88
+ feats = {k: (v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v)
89
+ for k, v in feats.items()}
90
+
91
+ # 3) match the model's half precision (input_features only)
92
+ if self.dtype == torch.float16 and "input_features" in feats:
93
+ feats["input_features"] = feats["input_features"].half()
94
+
95
+ return feats
96
+
97
+
98
+ def _forward(self, feats: dict) -> List[List[int]]:
99
+ outputs = self.model(**feats)
100
+ tokens = outputs.quantized_token_ids
101
+ # downsample the attention mask to match: conv downsampling × pooling
102
+ attn = feats["attention_mask"][
103
+ :, :: self.model.conv1.stride[0] * self.model.conv2.stride[0]
104
+ ][:, :: self.pooling_kernel_size]
105
+ return [t[m.bool()].tolist() for t, m in zip(tokens, attn)]
106
+
107
+ # -------- synchronous batch interface --------
108
+ def extract(self, utts: List[Union[str, torch.Tensor]]) -> List[List[int]]:
109
+ """
110
+ No 30 s chunking and no micro-batching.
111
+ Straight pipeline: load/resample -> GPU feature extraction -> forward -> mask-aligned output.
112
+ """
113
+ audios = [self._load_audio(u) for u in utts] # list[Tensor(T)] on device
114
+ with torch.inference_mode():
115
+ feats = self._extract_features_gpu(audios) # on device
116
+ return self._forward(feats)
117
+
118
+ # -------- asynchronous batch interface (keeps the original return protocol) --------
119
+ async def _batch_extract_async(self, utts: List[Union[str, torch.Tensor]]):
120
+ tokens_list = await asyncio.to_thread(self.extract, utts)
121
+ return [{"tokens": t} for t in tokens_list]
122
+
123
+ async def extract_async(self, utt: Union[str, torch.Tensor]):
124
+ result = await self._batch_processor.add_request(single_input=utt)
125
+ feature = result.get("feature")
126
+ return feature.get("tokens")
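A minimal usage sketch for this class (the model path is a placeholder and the imports assume this repo's package layout; note that `_load_audio` does not resample, so waveforms passed in as tensors should already be 16 kHz):

import asyncio
import torch
from transformers import WhisperFeatureExtractor
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor

MODEL_PATH = "/path/to/glm-4-voice-tokenizer"   # placeholder
model = WhisperVQEncoder.from_pretrained(MODEL_PATH).half()
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_PATH)
extractor = SpeechTokenExtractor(model, feature_extractor, device="cuda")

# synchronous: a list of file paths and/or 16 kHz waveforms
tokens = extractor.extract(["sample.wav", torch.randn(16000)])

# asynchronous: single requests are grouped into batches by AsyncBatchEngine
async def demo():
    return await extractor.extract_async("sample.wav")

tokens_async = asyncio.run(demo())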
models/glm_speech_tokenizer/test_speech_token_extractor.py ADDED
@@ -0,0 +1,136 @@
1
+
2
+
3
+ #!/usr/bin/env python3
4
+ # -*- coding: utf-8 -*-
5
+
6
+ import os
7
+ import sys
8
+ sys.path.append("../../..")
9
+ import asyncio
10
+ import time
11
+ from datetime import datetime
12
+
13
+ import torch
14
+ import torchaudio
15
+ from transformers import WhisperFeatureExtractor
16
+ from arktts.models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
17
+ from speech_token_extractor import SpeechTokenExtractor # the class implemented in this repo
18
+ _RESAMPLE_CACHE: dict[int, torchaudio.transforms.Resample] = {}
19
+
20
+ def ts() -> str:
21
+ return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
22
+
23
+ def sync_cuda(device: str):
24
+ if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
25
+ torch.cuda.synchronize(device=device)
26
+
27
+ def load_wav_as_tuple(path: str, target_sr: int = 16000):
28
+ """读取 wav -> (mono_waveform_1d, sample_rate);保持在CPU上交给 extractor 处理。"""
29
+ wav, sr = torchaudio.load(path) # [C, T]
30
+
31
+ if wav.ndim == 2 and wav.size(0) > 1:
32
+ wav = wav.mean(dim=0) # -> [T], downmix to mono
33
+ else:
34
+ wav = wav.squeeze(0) # [1, T] -> [T]
35
+ # ensure contiguous float32 (the feature extractor is faster with numpy.float32)
36
+ wav = wav.contiguous().to(torch.float32).cpu()
37
+ if sr != target_sr:
38
+ if sr not in _RESAMPLE_CACHE:
39
+ _RESAMPLE_CACHE[sr] = torchaudio.transforms.Resample(
40
+ orig_freq=sr, new_freq=target_sr
41
+ )
42
+ wav = _RESAMPLE_CACHE[sr](wav.unsqueeze(0)).squeeze(0)
43
+ sr = target_sr
44
+
45
+ # print(f"type wave is {type(wav)}")
46
+ return wav
47
+
48
+ async def main():
49
+ # --- 1️⃣ Path configuration ---
50
+ MODEL_PATH = "/data/yumu/model/glm-4-voice-tokenizer"
51
+ AUDIO_PATH1 = "/data/yumu/data/audio_data/qiduoduo_tts_out/00000013.wav"
52
+ AUDIO_PATH2 = "/data/yumu/data/audio_data/qiduoduo_tts_out/00000012.wav"
53
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
54
+
55
+ assert os.path.exists(AUDIO_PATH1), f"Audio file does not exist: {AUDIO_PATH1}"
56
+ assert os.path.exists(MODEL_PATH), f"Model path does not exist: {MODEL_PATH}"
57
+
58
+ print(f"[{ts()}] 启动测试")
59
+ print(f" - DEVICE : {DEVICE}")
60
+ print(f" - MODEL_PATH : {MODEL_PATH}")
61
+ print(f" - AUDIO1 : {AUDIO_PATH1}")
62
+ print(f" - AUDIO2 : {AUDIO_PATH2 if os.path.exists(AUDIO_PATH2) else '(不存在,将重复 AUDIO1)'}")
63
+
64
+ # --- 2️⃣ Read the audio into memory first ---
65
+ audio1 = load_wav_as_tuple(AUDIO_PATH1)
66
+ audio2 = load_wav_as_tuple(AUDIO_PATH2) if os.path.exists(AUDIO_PATH2) else audio1
67
+
68
+ # --- 3️⃣ Load the model and feature extractor ---
69
+ print(f"\n[{ts()}] 加载 WhisperVQ 模型与特征提取器中...")
70
+ t0 = time.perf_counter()
71
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_PATH)
72
+
73
+ model = WhisperVQEncoder.from_pretrained(MODEL_PATH).eval().to(DEVICE)
74
+ if DEVICE.startswith("cuda"):
75
+ model = model.half() # cast to half precision once
76
+ sync_cuda(DEVICE)
77
+ t1 = time.perf_counter()
78
+ print(f"[{ts()}] 模型加载完成,用时 {(t1 - t0)*1000:.1f} ms")
79
+
80
+ # --- 4️⃣ Initialize the extractor ---
81
+ t0 = time.perf_counter()
82
+ extractor = SpeechTokenExtractor(
83
+ model=model,
84
+ feature_extractor=feature_extractor,
85
+ device=DEVICE,
86
+ batch_size=400,
87
+ wait_timeout=0.01,
88
+ )
89
+ sync_cuda(DEVICE)
90
+ t1 = time.perf_counter()
91
+ print(f"[{ts()}] ✅ SpeechTokenExtractor 初始化完成,用时 {(t1 - t0)*1000:.1f} ms")
92
+
93
+ # --- 5️⃣ Synchronous test (pass the preloaded waveform instead of a path) ---
94
+ print(f"\n[{ts()}] [同步模式] extract() 开始")
95
+ t0 = time.perf_counter()
96
+ sync_tokens_list = extractor.extract([audio1]) # ★ changed: pass the in-memory waveform, not a path
97
+ sync_cuda(DEVICE)
98
+ t1 = time.perf_counter()
99
+ sync_tokens = sync_tokens_list[0]
100
+ print(f"[{ts()}] [同步模式] 完成:{len(sync_tokens)} tokens")
101
+ print(f" - 预览:{sync_tokens[:20]} ...")
102
+ print(f" - 耗时:{(t1 - t0)*1000:.1f} ms (单样本)")
103
+
104
+ # --- 6️⃣ Asynchronous test (also passes in-memory audio) ---
105
+ print(f"\n[{ts()}] [异步模式] extract_async() 并发开始")
106
+
107
+ async def async_worker(audio_utt):
108
+ t_a0 = time.perf_counter()
109
+ print(f"type audio_utt is {type(audio_utt)}")
110
+ tokens = await extractor.extract_async(audio_utt) # ★ changed: pass the in-memory waveform, not a path
111
+ sync_cuda(DEVICE)
112
+ t_a1 = time.perf_counter()
113
+ print(f" · → {len(tokens)} tokens, {(t_a1 - t_a0)*1000:.1f} ms")
114
+ return tokens, (t_a1 - t_a0)
115
+
116
+ # originally a 20+20 concurrency test; reduced here to 2+2 requests, with in-memory audio instead of paths
117
+ test_inputs = [audio1] * 2 + [audio2] * 2
118
+
119
+ t0 = time.perf_counter()
120
+ results = await asyncio.gather(*(async_worker(aud) for aud in test_inputs))
121
+ sync_cuda(DEVICE)
122
+ t1 = time.perf_counter()
123
+
124
+ per_req_ms = [dt * 1000 for _, dt in results]
125
+ all_tokens = [tokens for tokens, _ in results]
126
+
127
+ print(f"[{ts()}] [异步模式] 完成")
128
+ print(f" - 总请求数:{len(results)}")
129
+ print(f" - 总耗时 :{(t1 - t0)*1000:.1f} ms")
130
+ print(f" - 单请求耗时(ms):{[round(x,1) for x in per_req_ms]}")
131
+ print(f" - 平均单请求耗时:{(sum(per_req_ms)/len(per_req_ms)):.1f} ms")
132
+ print(f" - 任一结果预览 :{all_tokens[0][:10]}")
133
+ print(f"\n[{ts()}] ✅ 所有测试完成。")
134
+
135
+ if __name__ == "__main__":
136
+ asyncio.run(main())
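The test script hard-codes machine-specific paths. One possible tweak, offered only as a suggestion (the environment variable names below are invented, not part of the commit), is to read them from the environment so the test can run on other machines:

import os

MODEL_PATH = os.environ.get("GLM_TOKENIZER_PATH", "/data/yumu/model/glm-4-voice-tokenizer")
AUDIO_PATH1 = os.environ.get("TEST_AUDIO_1", "/data/yumu/data/audio_data/qiduoduo_tts_out/00000013.wav")
AUDIO_PATH2 = os.environ.get("TEST_AUDIO_2", AUDIO_PATH1)  # fall back to the first clip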
models/glm_speech_tokenizer/utils.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ import io
3
+ import glob
4
+ import math
5
+ import tarfile
6
+ import torch
7
+ import torchaudio
8
+ import safetensors
9
+ from .configuration_whisper import WhisperVQConfig
10
+ from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
11
+ from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
12
+ # import asyncio
13
+ # from ..batch_processor import AsyncBatchEngine # adjust to your own path
14
+ # from typing import List, Union, Tuple, Literal, Optional
15
+
16
+ def load_quantize_encoder(model_path):
17
+ config = WhisperVQConfig.from_pretrained(model_path)
18
+ config.quantize_encoder_only = True
19
+ model = WhisperVQEncoder(config)
20
+ state_dict = {}
21
+ for path in glob.glob(os.path.join(model_path, "model*.safetensors")):
22
+ with safetensors.safe_open(path, framework="pt", device="cpu") as f:
23
+ for key in f.keys():
24
+ if key.startswith("model.encoder."):
25
+ new_key = key[len("model.encoder."):]
26
+ if new_key.startswith("layer_norm"):
27
+ continue
28
+ if new_key.startswith("layers"):
29
+ layer_id = int(new_key.split(".")[1])
30
+ if layer_id >= config.quantize_position:
31
+ continue
32
+ state_dict[new_key] = f.get_tensor(key)
33
+ model.load_state_dict(state_dict)
34
+ model.eval()
35
+ model.cuda()
36
+ return model
37
+
38
+
39
+ _resample_buffer: dict[int, torchaudio.transforms.Resample] = {}
40
+
41
+
42
+ def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts,device="cuda"):
43
+ with torch.no_grad():
44
+ audios, indices = [], []
45
+ for idx, utt in enumerate(utts):
46
+ if isinstance(utt, tuple):
47
+ audio, sample_rate = utt
48
+ else:
49
+ audio, sample_rate = torchaudio.load(utt)
50
+ audio = audio.to(device)
51
+ if sample_rate != 16000:
52
+ if sample_rate not in _resample_buffer:
53
+ _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
54
+ orig_freq=sample_rate,
55
+ new_freq=16000
56
+ ).to(device)
57
+ audio = _resample_buffer[sample_rate](audio)
58
+ # if audio.shape[0] > 1:
59
+ # audio = audio[:1]
60
+ audio = audio[0]
61
+ audio = audio.cpu().numpy()
62
+ time_step = 0
63
+ while time_step * 16000 < audio.shape[0]:
64
+ audio_segment = audio[time_step * 16000: (time_step + 30) * 16000]
65
+ audios.append(audio_segment)
66
+ indices.append(idx)
67
+ time_step += 30
68
+ pooling_kernel_size = model.config.pooling_kernel_size or 1
69
+ stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length
70
+ all_speech_tokens = [[] for _ in range(len(utts))]
71
+ batch_size = 128
72
+ for start in range(0, len(audios), batch_size):
73
+ features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000,
74
+ return_attention_mask=True, return_tensors="pt", device=device,
75
+ padding="longest", pad_to_multiple_of=stride)
76
+ features = features.to(device=device)
77
+ # ✅ key fix: if the model is FP16, cast the inputs to FP16 as well
78
+ if next(model.parameters()).dtype == torch.float16:
79
+ features = {k: v.half() for k, v in features.items()}
80
+ outputs = model(**features)
81
+ speech_tokens = outputs.quantized_token_ids
82
+ attention_mask = features["attention_mask"][:, ::model.conv1.stride[0] * model.conv2.stride[0]]
83
+ attention_mask = attention_mask[:, ::model.config.pooling_kernel_size]
84
+ assert attention_mask.shape == speech_tokens.shape
85
+ for i in range(len(speech_tokens)):
86
+ idx = indices[start + i]
87
+ speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
88
+ all_speech_tokens[idx].extend(speech_token)
89
+ return all_speech_tokens
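For comparison with the class-based extractor above, a minimal sketch of calling this batch helper directly (the model path is a placeholder; imports assume this repo's package layout):

import torchaudio
from transformers import WhisperFeatureExtractor
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
from models.glm_speech_tokenizer.utils import extract_speech_token

MODEL_PATH = "/path/to/glm-4-voice-tokenizer"   # placeholder
model = WhisperVQEncoder.from_pretrained(MODEL_PATH).eval().cuda()
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_PATH)

# inputs may be file paths or (waveform, sample_rate) tuples; audio is resampled
# to 16 kHz and split into 30 s windows internally
utts = ["a.wav", torchaudio.load("b.wav")]
all_tokens = extract_speech_token(model, feature_extractor, utts, device="cuda")
print(len(all_tokens), len(all_tokens[0]))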
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ transformers==4.57.3
2
+ torch==2.8.0
3
+ librosa
4
+ soundfile
5
+ numpy