Spaces:
Running
on
L4
Running
on
L4
wanglamao
commited on
Commit
·
528efee
1
Parent(s):
e55409d
init
Browse files- app.py +206 -0
- data_utils/__init__.py +0 -0
- data_utils/audio_dataset_ark_audio.py +414 -0
- gpa_inference.py +293 -0
- models/__init__.py +0 -0
- models/bicodec_tokenizer/__init__.py +0 -0
- models/bicodec_tokenizer/base_model.py +87 -0
- models/bicodec_tokenizer/batch_processor.py +182 -0
- models/bicodec_tokenizer/models/__init__.py +0 -0
- models/bicodec_tokenizer/models/audio_tokenizer.py +164 -0
- models/bicodec_tokenizer/models/bicodec.py +248 -0
- models/bicodec_tokenizer/modules/blocks/layers.py +73 -0
- models/bicodec_tokenizer/modules/blocks/samper.py +115 -0
- models/bicodec_tokenizer/modules/blocks/vocos.py +373 -0
- models/bicodec_tokenizer/modules/encoder_decoder/feat_decoder.py +115 -0
- models/bicodec_tokenizer/modules/encoder_decoder/feat_encoder.py +107 -0
- models/bicodec_tokenizer/modules/encoder_decoder/wave_generator.py +88 -0
- models/bicodec_tokenizer/modules/fsq/finite_scalar_quantization.py +251 -0
- models/bicodec_tokenizer/modules/fsq/residual_fsq.py +355 -0
- models/bicodec_tokenizer/modules/speaker/__init__.py +0 -0
- models/bicodec_tokenizer/modules/speaker/ecapa_tdnn.py +267 -0
- models/bicodec_tokenizer/modules/speaker/perceiver_encoder.py +360 -0
- models/bicodec_tokenizer/modules/speaker/pooling_layers.py +298 -0
- models/bicodec_tokenizer/modules/speaker/speaker_encoder.py +136 -0
- models/bicodec_tokenizer/modules/vq/factorized_vector_quantize.py +187 -0
- models/bicodec_tokenizer/spark_detokenizer.py +106 -0
- models/bicodec_tokenizer/spark_tokenizer.py +244 -0
- models/bicodec_tokenizer/tokenizer_utils.py +44 -0
- models/bicodec_tokenizer/utils/__init__.py +0 -0
- models/bicodec_tokenizer/utils/audio.py +271 -0
- models/bicodec_tokenizer/utils/file.py +221 -0
- models/bicodec_tokenizer/utils/parse_options.sh +97 -0
- models/bicodec_tokenizer/utils/token_parser.py +187 -0
- models/glm_speech_tokenizer/__init__.py +0 -0
- models/glm_speech_tokenizer/batch_processor.py +182 -0
- models/glm_speech_tokenizer/configuration_whisper.py +37 -0
- models/glm_speech_tokenizer/generation_whisper.py +1828 -0
- models/glm_speech_tokenizer/modeling_whisper.py +0 -0
- models/glm_speech_tokenizer/speech_token_extractor.py +126 -0
- models/glm_speech_tokenizer/test_speech_token_extractor.py +136 -0
- models/glm_speech_tokenizer/utils.py +89 -0
- requirements.txt +5 -0
app.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import argparse
|
| 6 |
+
import librosa
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
|
| 9 |
+
from gpa_inference import GPAInference
|
| 10 |
+
|
| 11 |
+
# Global inference object placeholder
|
| 12 |
+
inference = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def preprocess_audio(audio_path):
|
| 16 |
+
"""Ensure audio is 16kHz mono"""
|
| 17 |
+
if not audio_path:
|
| 18 |
+
return None
|
| 19 |
+
try:
|
| 20 |
+
# Load audio with librosa: automatically resamples to sr=16000 and converts to mono
|
| 21 |
+
y, _ = librosa.load(audio_path, sr=16000, mono=True)
|
| 22 |
+
|
| 23 |
+
# Save processed audio to a new file to avoid conflicts
|
| 24 |
+
dir_name = os.path.dirname(audio_path)
|
| 25 |
+
base_name = os.path.basename(audio_path)
|
| 26 |
+
name, ext = os.path.splitext(base_name)
|
| 27 |
+
new_path = os.path.join(dir_name, f"{name}_16k.wav")
|
| 28 |
+
|
| 29 |
+
sf.write(new_path, y, 16000)
|
| 30 |
+
print(f"Preprocessed audio saved to: {new_path}")
|
| 31 |
+
return new_path
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"Error processing audio {audio_path}: {e}")
|
| 34 |
+
return audio_path
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ======================== Interface Call Logic ========================
|
| 38 |
+
|
| 39 |
+
def process_stt(audio_path):
|
| 40 |
+
global inference
|
| 41 |
+
if inference is None:
|
| 42 |
+
return "Model not initialized."
|
| 43 |
+
|
| 44 |
+
if not audio_path:
|
| 45 |
+
return "Please upload audio first."
|
| 46 |
+
|
| 47 |
+
# Preprocess audio
|
| 48 |
+
audio_path = preprocess_audio(audio_path)
|
| 49 |
+
|
| 50 |
+
# Direct inference call
|
| 51 |
+
return inference.run_stt(audio_path=audio_path, do_sample=False)
|
| 52 |
+
|
| 53 |
+
def process_tts_a(text, ref_audio):
|
| 54 |
+
global inference
|
| 55 |
+
if inference is None:
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
if not text or not ref_audio:
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
# Preprocess audio
|
| 62 |
+
ref_audio = preprocess_audio(ref_audio)
|
| 63 |
+
|
| 64 |
+
# Direct inference call
|
| 65 |
+
return inference.run_tts(
|
| 66 |
+
task="tts-a",
|
| 67 |
+
output_filename="tts_output.wav",
|
| 68 |
+
text=text,
|
| 69 |
+
ref_audio_path=ref_audio,
|
| 70 |
+
temperature=0.8,
|
| 71 |
+
do_sample=True,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def process_vc(src_audio, ref_audio):
|
| 75 |
+
global inference
|
| 76 |
+
if inference is None:
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
if not src_audio or not ref_audio:
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
# Preprocess audio
|
| 83 |
+
src_audio = preprocess_audio(src_audio)
|
| 84 |
+
ref_audio = preprocess_audio(ref_audio)
|
| 85 |
+
|
| 86 |
+
# Direct inference call
|
| 87 |
+
return inference.run_vc(
|
| 88 |
+
source_audio_path=src_audio,
|
| 89 |
+
ref_audio_path=ref_audio,
|
| 90 |
+
output_filename="vc_output.wav",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# ======================== Gradio UI Layout ========================
|
| 94 |
+
|
| 95 |
+
# Use a soft, premium theme with indigo/slate colors to replace the default orange
|
| 96 |
+
theme = gr.themes.Soft(
|
| 97 |
+
primary_hue="indigo",
|
| 98 |
+
secondary_hue="slate",
|
| 99 |
+
neutral_hue="slate",
|
| 100 |
+
font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
with gr.Blocks(title="General Purpose Audio System", theme=theme) as demo:
|
| 104 |
+
gr.Markdown("# General Purpose Audio System")
|
| 105 |
+
gr.Markdown("STT, TTS, and VC full-feature demo interface based on GPAEngine.")
|
| 106 |
+
|
| 107 |
+
with gr.Tabs():
|
| 108 |
+
# --- STT Tab ---
|
| 109 |
+
with gr.TabItem("🎙️ Speech to Text (STT)"):
|
| 110 |
+
with gr.Row():
|
| 111 |
+
stt_input = gr.Audio(label="Input Audio", type="filepath")
|
| 112 |
+
stt_output = gr.Textbox(label="Recognition Result", placeholder="Recognition result will be displayed here in real-time...", lines=5)
|
| 113 |
+
stt_btn = gr.Button("Start Recognition", variant="primary")
|
| 114 |
+
stt_btn.click(process_stt, inputs=stt_input, outputs=stt_output)
|
| 115 |
+
|
| 116 |
+
# --- TTS-A Tab ---
|
| 117 |
+
with gr.TabItem("👤 Text to Speech (TTS)"):
|
| 118 |
+
with gr.Row():
|
| 119 |
+
with gr.Column():
|
| 120 |
+
ttsa_text = gr.Textbox(label="Synthesis Text", value="Hello, I am generated by voice cloning.")
|
| 121 |
+
ttsa_ref = gr.Audio(label="Reference Audio (Voice Source)", type="filepath")
|
| 122 |
+
ttsa_output = gr.Audio(label="Synthesis Result")
|
| 123 |
+
ttsa_btn = gr.Button("Synthesize Now", variant="primary")
|
| 124 |
+
ttsa_btn.click(process_tts_a, inputs=[ttsa_text, ttsa_ref], outputs=ttsa_output)
|
| 125 |
+
|
| 126 |
+
# --- VC Tab ---
|
| 127 |
+
with gr.TabItem("🎭 Voice Conversion (VC)"):
|
| 128 |
+
with gr.Row():
|
| 129 |
+
with gr.Column():
|
| 130 |
+
vc_src = gr.Audio(label="Source Audio (Content Source)", type="filepath")
|
| 131 |
+
vc_ref = gr.Audio(label="Reference Audio (Voice Source)", type="filepath")
|
| 132 |
+
vc_output = gr.Audio(label="Conversion Result")
|
| 133 |
+
vc_btn = gr.Button("Start Conversion", variant="primary")
|
| 134 |
+
vc_btn.click(process_vc, inputs=[vc_src, vc_ref], outputs=vc_output)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def parse_args():
|
| 138 |
+
parser = argparse.ArgumentParser(description="GPA Audio System GUI")
|
| 139 |
+
|
| 140 |
+
# Model Paths
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--tokenizer_path",
|
| 143 |
+
type=str,
|
| 144 |
+
default="/data3/gpa_ckpt/gpa_final/glm-4-voice-tokenizer",
|
| 145 |
+
help="Path to GLM4 tokenizer",
|
| 146 |
+
)
|
| 147 |
+
parser.add_argument(
|
| 148 |
+
"--text_tokenizer_path",
|
| 149 |
+
type=str,
|
| 150 |
+
default="/data3/gpa_ckpt/gpa_final",
|
| 151 |
+
help="Path to text tokenizer",
|
| 152 |
+
)
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
"--bicodec_tokenizer_path",
|
| 155 |
+
type=str,
|
| 156 |
+
default="/data3/gpa_ckpt/gpa_final/BiCodec/",
|
| 157 |
+
help="Path to BiCodec tokenizer",
|
| 158 |
+
)
|
| 159 |
+
parser.add_argument(
|
| 160 |
+
"--gpa_model_path",
|
| 161 |
+
type=str,
|
| 162 |
+
default="/data3/gpa_ckpt/gpa_final",
|
| 163 |
+
help="Path to GPA model",
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# System Config
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--output_dir",
|
| 169 |
+
type=str,
|
| 170 |
+
default="./output_gui",
|
| 171 |
+
help="Directory to save output files",
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--device",
|
| 175 |
+
type=str,
|
| 176 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 177 |
+
help="Device to use",
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Server Config
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--server_name", type=str, default="0.0.0.0", help="Address for Gradio server"
|
| 183 |
+
)
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--server_port", type=int, default=7868, help="Port for Gradio server"
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
return parser.parse_args()
|
| 189 |
+
|
| 190 |
+
args = parse_args()
|
| 191 |
+
|
| 192 |
+
# Instantiate Model
|
| 193 |
+
print(f"Initializing GPA Inference System on {args.device}...")
|
| 194 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 195 |
+
|
| 196 |
+
inference = GPAInference(
|
| 197 |
+
tokenizer_path=args.tokenizer_path,
|
| 198 |
+
text_tokenizer_path=args.text_tokenizer_path,
|
| 199 |
+
bicodec_tokenizer_path=args.bicodec_tokenizer_path,
|
| 200 |
+
gpa_model_path=args.gpa_model_path,
|
| 201 |
+
output_dir=args.output_dir,
|
| 202 |
+
device=args.device,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Launch Gradio Demo
|
| 206 |
+
demo.queue().launch()
|
data_utils/__init__.py
ADDED
|
File without changes
|
data_utils/audio_dataset_ark_audio.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
|
| 4 |
+
from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
|
| 5 |
+
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
|
| 6 |
+
from transformers import PreTrainedTokenizer,AutoTokenizer,WhisperFeatureExtractor
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import logging
|
| 10 |
+
from typing import List, Dict, Any, Literal, Optional, Union
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
|
| 14 |
+
def has_punctuation(text: str) -> bool:
|
| 15 |
+
# 包含中英文符号
|
| 16 |
+
pattern = r"[,。!?;:()“”‘’、,.!?;:()\[\]{}\"']"
|
| 17 |
+
return bool(re.search(pattern, text))
|
| 18 |
+
|
| 19 |
+
ALL_TASKS = ["stt", "tts-a", "vc"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ark_infer_processor:
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
glm_tokenizer: SpeechTokenExtractor,
|
| 26 |
+
bicodec_tokenizer: SparkTokenizer,
|
| 27 |
+
text_tokenizer: PreTrainedTokenizer,
|
| 28 |
+
max_length: int = 512,
|
| 29 |
+
glm_semantic_token_offset: int = 151727,
|
| 30 |
+
semantic_token_offset: int = 172207,
|
| 31 |
+
global_token_offset: int = 168111,
|
| 32 |
+
audio_path_name: str = "audio",
|
| 33 |
+
device: str = "cpu",
|
| 34 |
+
):
|
| 35 |
+
self.glm_tokenizer = glm_tokenizer
|
| 36 |
+
self.bicodec_tokenizer = bicodec_tokenizer
|
| 37 |
+
self.text_tokenizer = text_tokenizer
|
| 38 |
+
self.max_length = max_length
|
| 39 |
+
self.glm_semantic_token_offset = glm_semantic_token_offset
|
| 40 |
+
self.semantic_token_offset = semantic_token_offset
|
| 41 |
+
self.global_token_offset = global_token_offset
|
| 42 |
+
self.device = device
|
| 43 |
+
self.audio_path_name = audio_path_name
|
| 44 |
+
|
| 45 |
+
def _process_example_stt(self, audio_path: str):
|
| 46 |
+
|
| 47 |
+
##target 音频
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
|
| 50 |
+
glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
|
| 51 |
+
|
| 52 |
+
semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
|
| 53 |
+
glm_semantic_tokens_list = (
|
| 54 |
+
(glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
|
| 55 |
+
)
|
| 56 |
+
semantic_tokens_list = (
|
| 57 |
+
(semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
|
| 58 |
+
)
|
| 59 |
+
input_ids = (
|
| 60 |
+
self.text_tokenizer.encode("<|start_glm_token|>")
|
| 61 |
+
+ glm_semantic_tokens_list
|
| 62 |
+
+ self.text_tokenizer.encode("<|end_glm_token|>")
|
| 63 |
+
+ self.text_tokenizer.encode("<|start_semantic_token|>")
|
| 64 |
+
+ semantic_tokens_list
|
| 65 |
+
+ self.text_tokenizer.encode("<|end_semantic_token|>")
|
| 66 |
+
+ self.text_tokenizer.encode("<|start_content|>")
|
| 67 |
+
)
|
| 68 |
+
attention_mask = [1] * (len(input_ids))
|
| 69 |
+
return input_ids, attention_mask
|
| 70 |
+
|
| 71 |
+
def _process_example_tts_a(self, text: str, ref_audio_path: str):
|
| 72 |
+
with torch.no_grad():
|
| 73 |
+
global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
|
| 74 |
+
all_text = "<|start_content|>" + text + "<|end_content|>"
|
| 75 |
+
global_tokens_list = (
|
| 76 |
+
(global_tokens + self.global_token_offset).cpu().tolist()[0][0]
|
| 77 |
+
)
|
| 78 |
+
text_tokens = self.text_tokenizer(
|
| 79 |
+
all_text, truncation=True, max_length=self.max_length
|
| 80 |
+
)
|
| 81 |
+
input_ids = (
|
| 82 |
+
self.text_tokenizer.encode("<|start_global_token|>")
|
| 83 |
+
+ global_tokens_list
|
| 84 |
+
+ self.text_tokenizer.encode("<|end_global_token|>")
|
| 85 |
+
+ text_tokens["input_ids"]
|
| 86 |
+
)
|
| 87 |
+
attention_mask = [1] * len(input_ids)
|
| 88 |
+
return input_ids, attention_mask
|
| 89 |
+
|
| 90 |
+
def _process_example_vc(self, audio_path: str, ref_audio_path: str):
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
|
| 93 |
+
new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
|
| 94 |
+
semantic_tokens_list = (
|
| 95 |
+
(semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
|
| 96 |
+
)
|
| 97 |
+
new_global_tokens_list = (
|
| 98 |
+
(new_global_tokens + self.global_token_offset).cpu().tolist()[0][0]
|
| 99 |
+
)
|
| 100 |
+
all_str = (
|
| 101 |
+
"<|start_global_token|>"
|
| 102 |
+
+ self.text_tokenizer.decode(new_global_tokens_list)
|
| 103 |
+
+ "<|end_global_token|>"
|
| 104 |
+
+ "<|start_semantic_token|>"
|
| 105 |
+
+ self.text_tokenizer.decode(semantic_tokens_list)
|
| 106 |
+
+ "<|end_semantic_token|>"
|
| 107 |
+
+ "<|end_content|>"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
inputs = self.text_tokenizer(all_str)
|
| 111 |
+
input_ids = inputs["input_ids"]
|
| 112 |
+
attention_mask = inputs["attention_mask"]
|
| 113 |
+
return input_ids, attention_mask
|
| 114 |
+
|
| 115 |
+
def process_input(
|
| 116 |
+
self,
|
| 117 |
+
task: Literal["stt", "tts-a", "vc"],
|
| 118 |
+
audio_path: str | None = None,
|
| 119 |
+
ref_audio_path: str | None = None,
|
| 120 |
+
text: str | None = None,
|
| 121 |
+
):
|
| 122 |
+
"""加载指定音频、特征并根据任务类型返回 token 化结果。"""
|
| 123 |
+
|
| 124 |
+
if task == "stt":
|
| 125 |
+
assert audio_path is not None
|
| 126 |
+
input_ids, attention_mask = self._process_example_stt(audio_path)
|
| 127 |
+
elif task == "tts-a":
|
| 128 |
+
assert ref_audio_path is not None and text is not None
|
| 129 |
+
input_ids, attention_mask = self._process_example_tts_a(
|
| 130 |
+
text, ref_audio_path
|
| 131 |
+
)
|
| 132 |
+
elif task == "vc":
|
| 133 |
+
assert audio_path is not None and ref_audio_path is not None
|
| 134 |
+
input_ids, attention_mask = self._process_example_vc(
|
| 135 |
+
audio_path, ref_audio_path
|
| 136 |
+
)
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError(
|
| 139 |
+
f"Unsupported task: {task}, all supported tasks: {ALL_TASKS}"
|
| 140 |
+
)
|
| 141 |
+
return {
|
| 142 |
+
"input_ids": input_ids,
|
| 143 |
+
"attention_mask": attention_mask,
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ark_processor:
|
| 148 |
+
def __init__(self,
|
| 149 |
+
glm_tokenizer: SpeechTokenExtractor,
|
| 150 |
+
bicodec_tokenizer: SparkTokenizer,
|
| 151 |
+
text_tokenizer:PreTrainedTokenizer,
|
| 152 |
+
max_length:int = 512,
|
| 153 |
+
glm_semantic_token_offset:int = 151727,
|
| 154 |
+
semantic_token_offset: int =172207,
|
| 155 |
+
global_token_offset: int =168111,
|
| 156 |
+
audio_path_name:str = "audio",
|
| 157 |
+
device:str ='cpu'):
|
| 158 |
+
self.glm_tokenizer = glm_tokenizer
|
| 159 |
+
self.bicodec_tokenizer = bicodec_tokenizer
|
| 160 |
+
self.text_tokenizer = text_tokenizer
|
| 161 |
+
self.max_length = max_length
|
| 162 |
+
self.glm_semantic_token_offset =glm_semantic_token_offset
|
| 163 |
+
self.semantic_token_offset=semantic_token_offset
|
| 164 |
+
self.global_token_offset=global_token_offset
|
| 165 |
+
self.device = device
|
| 166 |
+
self.audio_path_name =audio_path_name
|
| 167 |
+
|
| 168 |
+
def process_example(self, example: Dict[str, Any]):
|
| 169 |
+
"""
|
| 170 |
+
这个函数由多个CPU进程并行执行。
|
| 171 |
+
它负责加载、重采样和对单个样本进行特征提取/分词。
|
| 172 |
+
"""
|
| 173 |
+
task = example.get("task", "stt")
|
| 174 |
+
audio_path = example.get(self.audio_path_name, "")
|
| 175 |
+
ref_audio_path = example.get("ref_audio", "")
|
| 176 |
+
vc_audio = example.get("vc_audio", "")
|
| 177 |
+
text = example.get("text", "")
|
| 178 |
+
|
| 179 |
+
if task == "stt":
|
| 180 |
+
##target 音频
|
| 181 |
+
with torch.no_grad():
|
| 182 |
+
glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
|
| 183 |
+
glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
|
| 184 |
+
|
| 185 |
+
semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
|
| 186 |
+
glm_semantic_tokens_list = (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
|
| 187 |
+
semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
|
| 188 |
+
# print(f"len of semantic is {len(semantic_tokens_list)}")
|
| 189 |
+
##对text进行token
|
| 190 |
+
text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length)
|
| 191 |
+
|
| 192 |
+
input_ids = self.text_tokenizer.encode("<|start_glm_token|>") + glm_semantic_tokens_list + self.text_tokenizer.encode("<|end_glm_token|>") \
|
| 193 |
+
+ self.text_tokenizer.encode("<|start_semantic_token|>") + semantic_tokens_list + self.text_tokenizer.encode(
|
| 194 |
+
"<|end_semantic_token|>") \
|
| 195 |
+
+ self.text_tokenizer.encode("<|start_content|>") + text_tokens["input_ids"] + self.text_tokenizer.encode("<|end_content|>") \
|
| 196 |
+
+ self.text_tokenizer.encode("<|im_end|>")
|
| 197 |
+
attention_mask = [1] * (len(input_ids))
|
| 198 |
+
labels = [-100] * (len(semantic_tokens_list) + 5 + len(glm_semantic_tokens_list)) + text_tokens["input_ids"] + self.text_tokenizer.encode(
|
| 199 |
+
"<|end_content|>") + self.text_tokenizer.encode("<|im_end|>")
|
| 200 |
+
|
| 201 |
+
elif task == "tts-a":
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
|
| 204 |
+
global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
|
| 205 |
+
all_text = "<|start_content|>" + text + "<|end_content|>"
|
| 206 |
+
global_tokens_list = (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
|
| 207 |
+
text_tokens = self.text_tokenizer(all_text, truncation=True, max_length=self.max_length)
|
| 208 |
+
semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
|
| 209 |
+
input_ids = self.text_tokenizer.encode("<|start_global_token|>") + global_tokens_list + self.text_tokenizer.encode(
|
| 210 |
+
"<|end_global_token|>") + text_tokens["input_ids"] + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
|
| 211 |
+
attention_mask = [1] * len(input_ids)
|
| 212 |
+
labels = [-100] * (len(text_tokens["input_ids"]) + 2 + len(global_tokens_list)) + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
|
| 213 |
+
|
| 214 |
+
elif task == "vc":
|
| 215 |
+
with torch.no_grad():
|
| 216 |
+
semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
|
| 217 |
+
global_tokens = self.bicodec_tokenizer.tokenize([audio_path])['global_tokens']
|
| 218 |
+
# global_tokens, semantic_tokens=self.bicodec_tokenizer.tokenize(audio_path=audio_path)
|
| 219 |
+
# new_global_tokens, new_semantic_tokens=self.bicodec_tokenizer.tokenize(vc_audio,ref_audio_path)
|
| 220 |
+
new_semantic_tokens = self.bicodec_tokenizer.tokenize([vc_audio])['semantic_tokens']
|
| 221 |
+
new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
|
| 222 |
+
|
| 223 |
+
global_tokens_list = (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
|
| 224 |
+
semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
|
| 225 |
+
new_global_tokens_list = (new_global_tokens + self.global_token_offset).cpu().tolist()[0][0]
|
| 226 |
+
new_semantic_tokens_list = (new_semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
|
| 227 |
+
all_str = "<|start_global_token|>" + self.text_tokenizer.decode(new_global_tokens_list) + "<|end_global_token|>" + "<|start_semantic_token|>" + self.text_tokenizer.decode(
|
| 228 |
+
semantic_tokens_list) + "<|end_semantic_token|>" + "<|end_content|>" + self.text_tokenizer.decode(new_semantic_tokens_list) + "<|im_end|>"
|
| 229 |
+
|
| 230 |
+
##add token and mask
|
| 231 |
+
inputs = self.text_tokenizer(all_str)
|
| 232 |
+
input_ids = inputs['input_ids']
|
| 233 |
+
attention_mask = inputs['attention_mask']
|
| 234 |
+
labels = [-100] * (5 + len(new_global_tokens_list) + len(semantic_tokens_list)) + new_semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
|
| 235 |
+
else:
|
| 236 |
+
##默认走stt
|
| 237 |
+
with torch.no_grad():
|
| 238 |
+
glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
|
| 239 |
+
glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
|
| 240 |
+
|
| 241 |
+
semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
|
| 242 |
+
glm_semantic_tokens_list = (glm_semantic_tokens+self.glm_semantic_token_offset).cpu().tolist()[0]
|
| 243 |
+
semantic_tokens_list = (semantic_tokens+self.semantic_token_offset).cpu().tolist()[0]
|
| 244 |
+
# print(f"len of semantic is {len(semantic_tokens_list)}")
|
| 245 |
+
##对text进行token
|
| 246 |
+
text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length)
|
| 247 |
+
|
| 248 |
+
input_ids = self.text_tokenizer.encode("<|start_glm_token|>")+ glm_semantic_tokens_list + self.text_tokenizer.encode("<|end_glm_token|>") \
|
| 249 |
+
+ self.text_tokenizer.encode("<|start_semantic_token|>")+ semantic_tokens_list + self.text_tokenizer.encode("<|end_semantic_token|>") \
|
| 250 |
+
+ text_tokens["input_ids"] \
|
| 251 |
+
+ self.text_tokenizer.encode("<|im_end|>")
|
| 252 |
+
attention_mask = [1]*(len(semantic_tokens_list)+4+len(glm_semantic_tokens_list)) +text_tokens["attention_mask"] +[1]
|
| 253 |
+
labels = [-100]*(len(semantic_tokens_list)+4+len(glm_semantic_tokens_list))+ text_tokens["input_ids"]+ self.text_tokenizer.encode("<|im_end|>")
|
| 254 |
+
return {
|
| 255 |
+
"input_ids": input_ids,
|
| 256 |
+
"attention_mask": attention_mask,
|
| 257 |
+
"labels": labels,
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def create_tts_collate_fn(
|
| 262 |
+
pad_token_id: int,
|
| 263 |
+
processor, # ark_processor
|
| 264 |
+
max_length: Optional[int]=None,# 传入你想要的截断上限,例如 512
|
| 265 |
+
truncation_side: str = "right" # "right" 或 "left",默认右截断
|
| 266 |
+
):
|
| 267 |
+
"""
|
| 268 |
+
手动填充 + 可选截断的 collate_fn 工厂。
|
| 269 |
+
|
| 270 |
+
参数:
|
| 271 |
+
pad_token_id: 用于 input_ids 的 pad 值
|
| 272 |
+
processor: 你的 ark_processor,需提供 .process_example()
|
| 273 |
+
max_length: 若提供,则对每个样本在拼批前先截断到该长度
|
| 274 |
+
truncation_side: "right" | "left",决定从哪侧截断
|
| 275 |
+
"""
|
| 276 |
+
label_pad_value = -100
|
| 277 |
+
attention_mask_pad_value = 0
|
| 278 |
+
|
| 279 |
+
def _truncate_1d(x: torch.Tensor, keep_len: int, side: str) -> torch.Tensor:
|
| 280 |
+
if x.numel() <= keep_len:
|
| 281 |
+
return x
|
| 282 |
+
if side == "right":
|
| 283 |
+
return x[:keep_len]
|
| 284 |
+
elif side == "left":
|
| 285 |
+
return x[-keep_len:]
|
| 286 |
+
else:
|
| 287 |
+
raise ValueError(f"Unsupported truncation_side: {side}")
|
| 288 |
+
|
| 289 |
+
def _to_long_tensor(x) -> torch.Tensor:
|
| 290 |
+
if isinstance(x, torch.Tensor):
|
| 291 |
+
return x.detach().clone().long()
|
| 292 |
+
return torch.tensor(x, dtype=torch.long)
|
| 293 |
+
|
| 294 |
+
def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
| 295 |
+
# 1) 预处理(过滤空样本)
|
| 296 |
+
proc = [processor.process_example(ex) for ex in examples if ex]
|
| 297 |
+
proc = [d for d in proc if d and ("input_ids" in d) and ("attention_mask" in d) and ("labels" in d)]
|
| 298 |
+
|
| 299 |
+
if len(proc) == 0:
|
| 300 |
+
# 返回空批,避免 DataLoader 崩溃
|
| 301 |
+
return {
|
| 302 |
+
"input_ids": torch.empty(0, dtype=torch.long),
|
| 303 |
+
"attention_mask": torch.empty(0, dtype=torch.long),
|
| 304 |
+
"labels": torch.empty(0, dtype=torch.long),
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
# 2) 样本级截断(如果设置了 max_length)
|
| 308 |
+
if max_length is not None:
|
| 309 |
+
trimmed = []
|
| 310 |
+
for ex in proc:
|
| 311 |
+
ids = _to_long_tensor(ex["input_ids"])
|
| 312 |
+
mask = _to_long_tensor(ex["attention_mask"])
|
| 313 |
+
labs = _to_long_tensor(ex["labels"])
|
| 314 |
+
|
| 315 |
+
keep_len = min(max_length, ids.numel())
|
| 316 |
+
ids = _truncate_1d(ids, keep_len, truncation_side)
|
| 317 |
+
mask = _truncate_1d(mask, keep_len, truncation_side)
|
| 318 |
+
labs = _truncate_1d(labs, keep_len, truncation_side)
|
| 319 |
+
|
| 320 |
+
trimmed.append({"input_ids": ids, "attention_mask": mask, "labels": labs})
|
| 321 |
+
proc = trimmed
|
| 322 |
+
|
| 323 |
+
# 3) 计算本批最大长度(截断后再取最大)
|
| 324 |
+
max_len_in_batch = max(int(len(ex["input_ids"])) for ex in proc)
|
| 325 |
+
|
| 326 |
+
# 4) 逐样本右侧 pad 到 batch 最大长度
|
| 327 |
+
padded_input_ids_list = []
|
| 328 |
+
padded_attention_mask_list = []
|
| 329 |
+
padded_labels_list = []
|
| 330 |
+
|
| 331 |
+
for ex in proc:
|
| 332 |
+
ids = _to_long_tensor(ex["input_ids"])
|
| 333 |
+
mask = _to_long_tensor(ex["attention_mask"])
|
| 334 |
+
labs = _to_long_tensor(ex["labels"])
|
| 335 |
+
|
| 336 |
+
need = max_len_in_batch - ids.numel()
|
| 337 |
+
if need < 0:
|
| 338 |
+
# 极端情况:有人为 max_length=None 时超长样本溢出
|
| 339 |
+
keep_len = max_len_in_batch
|
| 340 |
+
ids = _truncate_1d(ids, keep_len, "right")
|
| 341 |
+
mask = _truncate_1d(mask, keep_len, "right")
|
| 342 |
+
labs = _truncate_1d(labs, keep_len, "right")
|
| 343 |
+
need = 0
|
| 344 |
+
|
| 345 |
+
pad_dims = (0, need)
|
| 346 |
+
ids = F.pad(ids, pad_dims, mode="constant", value=pad_token_id)
|
| 347 |
+
mask = F.pad(mask, pad_dims, mode="constant", value=attention_mask_pad_value)
|
| 348 |
+
labs = F.pad(labs, pad_dims, mode="constant", value=label_pad_value)
|
| 349 |
+
|
| 350 |
+
padded_input_ids_list.append(ids)
|
| 351 |
+
padded_attention_mask_list.append(mask)
|
| 352 |
+
padded_labels_list.append(labs)
|
| 353 |
+
|
| 354 |
+
# 5) 堆叠成批
|
| 355 |
+
batch = {
|
| 356 |
+
"input_ids": torch.stack(padded_input_ids_list, dim=0),
|
| 357 |
+
"attention_mask": torch.stack(padded_attention_mask_list, dim=0),
|
| 358 |
+
"labels": torch.stack(padded_labels_list, dim=0),
|
| 359 |
+
}
|
| 360 |
+
return batch
|
| 361 |
+
|
| 362 |
+
return collate_fn
|
| 363 |
+
|
| 364 |
+
if __name__ == "__main__":
|
| 365 |
+
device = "cuda:0"
|
| 366 |
+
bicodec_audio_tokenizer_path = "/data/arki_production/model/SparkAudio/Spark-TTS-0___5B/"
|
| 367 |
+
glm_speech_tokenizer_path = "/data/yumu/model/glm-4-voice-tokenizer"
|
| 368 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(glm_speech_tokenizer_path)
|
| 369 |
+
audio_model = WhisperVQEncoder.from_pretrained(glm_speech_tokenizer_path).eval().to(device)
|
| 370 |
+
glm_tokenizer = SpeechTokenExtractor(model=audio_model, feature_extractor=feature_extractor, device=device)
|
| 371 |
+
|
| 372 |
+
text_tokenizer = AutoTokenizer.from_pretrained("/data/yumu/model/ark_audio_v1_0_3_b",trust_remote_code=True)
|
| 373 |
+
bicodec_tokenizer = SparkTokenizer(model_path=bicodec_audio_tokenizer_path, device=device)
|
| 374 |
+
# 配置项
|
| 375 |
+
DATASET_PATH = "/data/yumu/glm_asr_vllm/test/data/test_meeting.jsonl"
|
| 376 |
+
MAX_LENGTH = 4096
|
| 377 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 378 |
+
|
| 379 |
+
print(f"将使用设备: {DEVICE}")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
# --- 2. 加载流式数据集 ---
|
| 383 |
+
|
| 384 |
+
print(f"以流式方式加载数据集 '{DATASET_PATH}'...")
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
streaming_dataset = load_dataset("json", data_files=DATASET_PATH, streaming=True)['train']
|
| 388 |
+
# --- 4. 构建数据处理流水线 (Pipeline) ---
|
| 389 |
+
|
| 390 |
+
print("正在对数据流进行shuffle,buffer_size=1000...")
|
| 391 |
+
shuffled_dataset = streaming_dataset.shuffle(buffer_size=10000, seed=42)
|
| 392 |
+
processor = ark_processor(
|
| 393 |
+
glm_tokenizer=glm_tokenizer,
|
| 394 |
+
bicodec_tokenizer=bicodec_tokenizer,
|
| 395 |
+
text_tokenizer=text_tokenizer,
|
| 396 |
+
device = DEVICE,
|
| 397 |
+
audio_path_name="audio")
|
| 398 |
+
collate_fn = create_tts_collate_fn(text_tokenizer.pad_token_id,processor,max_length=4096)
|
| 399 |
+
# 创建最终的DataLoader
|
| 400 |
+
data_loader = DataLoader(
|
| 401 |
+
shuffled_dataset,
|
| 402 |
+
batch_size=10, # 根据你的GPU显存和模型大小调整
|
| 403 |
+
collate_fn=collate_fn,
|
| 404 |
+
num_workers=0 # DataLoader的worker,负责从打乱后的流中拉取数据
|
| 405 |
+
)
|
| 406 |
+
print("\n--- 高性能流式 DataLoader 演示 ---")
|
| 407 |
+
print("将从DataLoader中获取并展示第一个批次的数据:\n")
|
| 408 |
+
first_batch = next(iter(data_loader))
|
| 409 |
+
|
| 410 |
+
print("成功获取第一个批次!数据已在collate_fn中填充。")
|
| 411 |
+
for key, value in first_batch.items():
|
| 412 |
+
if value is not None:
|
| 413 |
+
# print(f" - {key}: shape={value.shape}, dtype={value.dtype}")
|
| 414 |
+
print(f" - {key}: shape={value.shape}, dtype={value}")
|
gpa_inference.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import soundfile as sf
|
| 5 |
+
import re
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, WhisperFeatureExtractor
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
|
| 10 |
+
from models.bicodec_tokenizer.spark_detokenizer import SparkDeTokenizer
|
| 11 |
+
|
| 12 |
+
from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
|
| 13 |
+
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
|
| 14 |
+
|
| 15 |
+
from data_utils.audio_dataset_ark_audio import ark_infer_processor
|
| 16 |
+
|
| 17 |
+
class GPAInference:
|
| 18 |
+
def __init__(self, tokenizer_path, text_tokenizer_path, bicodec_tokenizer_path, gpa_model_path, output_dir, device):
|
| 19 |
+
self.tokenizer_path = tokenizer_path
|
| 20 |
+
self.text_tokenizer_path = text_tokenizer_path
|
| 21 |
+
self.bicodec_tokenizer_path = bicodec_tokenizer_path
|
| 22 |
+
self.gpa_model_path = gpa_model_path
|
| 23 |
+
self.output_dir = output_dir
|
| 24 |
+
self.device = device
|
| 25 |
+
|
| 26 |
+
print(f"Using device: {self.device}")
|
| 27 |
+
self._load_models()
|
| 28 |
+
|
| 29 |
+
def _load_models(self):
|
| 30 |
+
print("Loading tokenizers...")
|
| 31 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(self.tokenizer_path)
|
| 32 |
+
audio_model = WhisperVQEncoder.from_pretrained(self.tokenizer_path).eval().to(self.device)
|
| 33 |
+
self.glm_tokenizer = SpeechTokenExtractor(model=audio_model, feature_extractor=feature_extractor, device=self.device)
|
| 34 |
+
self.text_tokenizer = AutoTokenizer.from_pretrained(
|
| 35 |
+
self.text_tokenizer_path,
|
| 36 |
+
trust_remote_code=True
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
self.bicodec_tokenizer = SparkTokenizer(model_path=self.bicodec_tokenizer_path, device=self.device)
|
| 40 |
+
self.bicodec_detokenizer = SparkDeTokenizer(model_path=self.bicodec_tokenizer_path, device=self.device)
|
| 41 |
+
self.processor = ark_infer_processor(
|
| 42 |
+
glm_tokenizer=self.glm_tokenizer,
|
| 43 |
+
bicodec_tokenizer=self.bicodec_tokenizer,
|
| 44 |
+
text_tokenizer=self.text_tokenizer,
|
| 45 |
+
device=self.device,
|
| 46 |
+
audio_path_name="audio",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
print("Loading model...")
|
| 50 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 51 |
+
self.gpa_model_path,
|
| 52 |
+
trust_remote_code=True
|
| 53 |
+
).to(self.device)
|
| 54 |
+
|
| 55 |
+
def generate(self, inputs, **kwargs):
|
| 56 |
+
"""
|
| 57 |
+
Base generation method that accepts dynamic generation parameters.
|
| 58 |
+
"""
|
| 59 |
+
for k in inputs:
|
| 60 |
+
if isinstance(inputs[k], (list, np.ndarray)):
|
| 61 |
+
inputs[k] = torch.tensor(inputs[k]).unsqueeze(0).to(self.device)
|
| 62 |
+
elif isinstance(inputs[k], torch.Tensor):
|
| 63 |
+
inputs[k] = inputs[k].unsqueeze(0).to(self.device)
|
| 64 |
+
|
| 65 |
+
# Default generation config
|
| 66 |
+
generation_config = {
|
| 67 |
+
"max_new_tokens": 1000,
|
| 68 |
+
"do_sample": False,
|
| 69 |
+
"eos_token_id": self.text_tokenizer.convert_tokens_to_ids("<|im_end|>"),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
# Override defaults with any passed kwargs
|
| 73 |
+
generation_config.update(kwargs)
|
| 74 |
+
|
| 75 |
+
# Remove keys that might be None if passed from args mistakenly
|
| 76 |
+
generation_config = {k: v for k, v in generation_config.items() if v is not None}
|
| 77 |
+
print(f"Generation config: {generation_config}")
|
| 78 |
+
|
| 79 |
+
outputs = self.model.generate(
|
| 80 |
+
input_ids=inputs["input_ids"],
|
| 81 |
+
attention_mask=inputs["attention_mask"],
|
| 82 |
+
**generation_config
|
| 83 |
+
)
|
| 84 |
+
return outputs
|
| 85 |
+
|
| 86 |
+
def run_stt(self, audio_path, **kwargs):
|
| 87 |
+
if not audio_path:
|
| 88 |
+
raise ValueError("audio_path is required for STT")
|
| 89 |
+
|
| 90 |
+
print("\n--- Speech to Text (STT) ---")
|
| 91 |
+
|
| 92 |
+
inputs = self.processor.process_input(
|
| 93 |
+
task="stt",
|
| 94 |
+
audio_path=audio_path,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# recommend hyperparameters for TTS
|
| 98 |
+
kwargs = {
|
| 99 |
+
"max_new_tokens": 512,
|
| 100 |
+
"do_sample": False,
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Pass generation arguments (temperature, etc.) to generate
|
| 104 |
+
outputs = self.generate(inputs, **kwargs)
|
| 105 |
+
text = self.text_tokenizer.decode(outputs[0].tolist())
|
| 106 |
+
|
| 107 |
+
if "<|start_content|>" in text:
|
| 108 |
+
return text.split("<|start_content|>")[1].replace("<|im_end|>","").replace("<|end_content|>","")
|
| 109 |
+
else:
|
| 110 |
+
return text.replace("<|im_end|>","")
|
| 111 |
+
|
| 112 |
+
def run_tts(self, task, output_filename, text, ref_audio_path, **kwargs):
|
| 113 |
+
"""
|
| 114 |
+
gen_kwargs: dict, parameters for model.generate (temp, top_p, etc.)
|
| 115 |
+
"""
|
| 116 |
+
if not text:
|
| 117 |
+
raise ValueError("text is required for TTS")
|
| 118 |
+
|
| 119 |
+
# Check ref_audio_path requirement based on task
|
| 120 |
+
if task == "tts-a" and not ref_audio_path:
|
| 121 |
+
raise ValueError(f"ref_audio_path is required for {task}")
|
| 122 |
+
|
| 123 |
+
# recommend hyperparameters for TTS
|
| 124 |
+
kwargs = {
|
| 125 |
+
"max_new_tokens": 512,
|
| 126 |
+
"temperature": 0.2,
|
| 127 |
+
"repetition_penalty": 1.2,
|
| 128 |
+
"do_sample": True,
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
print(f"\n--- {task.upper()} ---")
|
| 132 |
+
output_path = os.path.join(self.output_dir, output_filename)
|
| 133 |
+
|
| 134 |
+
# Pass processor specific args (e.g. emotion, pitch) here
|
| 135 |
+
inputs = self.processor.process_input(
|
| 136 |
+
task=task,
|
| 137 |
+
ref_audio_path=ref_audio_path,
|
| 138 |
+
text=text,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Pass generation specific args (e.g. temperature) here
|
| 142 |
+
# Note: Original code hardcoded temperature=0.8 for TTS, we use gen_kwargs or fallback to generate defaults
|
| 143 |
+
outputs = self.generate(inputs, **kwargs)
|
| 144 |
+
|
| 145 |
+
text_output = self.text_tokenizer.decode(outputs[0].tolist())
|
| 146 |
+
|
| 147 |
+
if "<|end_content|>" in text_output:
|
| 148 |
+
content = text_output.split("<|end_content|>")[1]
|
| 149 |
+
else:
|
| 150 |
+
print("Warning: <|end_content|> not found")
|
| 151 |
+
content = text_output
|
| 152 |
+
|
| 153 |
+
audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
|
| 154 |
+
audio_list = [int(x) for x in audio_ids]
|
| 155 |
+
|
| 156 |
+
if ref_audio_path:
|
| 157 |
+
global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
|
| 158 |
+
else:
|
| 159 |
+
global_tokens = torch.zeros((1, 32), dtype=torch.long).to(self.device)
|
| 160 |
+
|
| 161 |
+
req = {
|
| 162 |
+
"global_tokens": global_tokens,
|
| 163 |
+
"semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
|
| 164 |
+
}
|
| 165 |
+
out = self.bicodec_detokenizer.detokenize(**req)
|
| 166 |
+
reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
|
| 167 |
+
# Simple DC offset removal
|
| 168 |
+
if reconstructed_wav.size > 0:
|
| 169 |
+
reconstructed_wav -= reconstructed_wav.mean()
|
| 170 |
+
|
| 171 |
+
sf.write(output_path, reconstructed_wav, 16000)
|
| 172 |
+
print(f"Saved output to {output_path}")
|
| 173 |
+
return 16000, reconstructed_wav
|
| 174 |
+
|
| 175 |
+
def run_vc(
|
| 176 |
+
self,
|
| 177 |
+
source_audio_path,
|
| 178 |
+
ref_audio_path,
|
| 179 |
+
output_filename="output_gpa_vc.wav",
|
| 180 |
+
**kwargs,
|
| 181 |
+
):
|
| 182 |
+
if not source_audio_path:
|
| 183 |
+
raise ValueError("source_audio_path is required for VC")
|
| 184 |
+
if not ref_audio_path:
|
| 185 |
+
raise ValueError("ref_audio_path is required for VC")
|
| 186 |
+
|
| 187 |
+
print("\n--- Voice Conversion (VC) ---")
|
| 188 |
+
output_path = os.path.join(self.output_dir, output_filename)
|
| 189 |
+
|
| 190 |
+
inputs = self.processor.process_input(
|
| 191 |
+
task="vc",
|
| 192 |
+
audio_path=source_audio_path,
|
| 193 |
+
ref_audio_path=ref_audio_path,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
outputs = self.generate(inputs, **kwargs)
|
| 197 |
+
text_output = self.text_tokenizer.decode(outputs[0].tolist())
|
| 198 |
+
|
| 199 |
+
if "<|end_content|>" in text_output:
|
| 200 |
+
content = text_output.split("<|end_content|>")[1]
|
| 201 |
+
else:
|
| 202 |
+
content = text_output
|
| 203 |
+
|
| 204 |
+
audio_ids = re.findall(r"<\|bicodec_semantic_(\d+)\|>", content)
|
| 205 |
+
audio_list = [int(x) for x in audio_ids]
|
| 206 |
+
|
| 207 |
+
global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
|
| 208 |
+
|
| 209 |
+
req = {
|
| 210 |
+
"global_tokens": global_tokens,
|
| 211 |
+
"semantic_tokens": torch.tensor(audio_list).unsqueeze(0).to(self.device),
|
| 212 |
+
}
|
| 213 |
+
out = self.bicodec_detokenizer.detokenize(**req)
|
| 214 |
+
reconstructed_wav = out.detach().cpu().float().squeeze().numpy()
|
| 215 |
+
if reconstructed_wav.size > 0:
|
| 216 |
+
reconstructed_wav -= reconstructed_wav.mean()
|
| 217 |
+
|
| 218 |
+
sf.write(output_path, reconstructed_wav, 16000)
|
| 219 |
+
print(f"Saved VC output to {output_path}")
|
| 220 |
+
return 16000, reconstructed_wav
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def parse_args():
|
| 224 |
+
parser = argparse.ArgumentParser(description="GPA Inference Script")
|
| 225 |
+
|
| 226 |
+
# Paths
|
| 227 |
+
parser.add_argument("--tokenizer_path", type=str, default="/nasdata/model/gpa/glm-4-voice-tokenizer", help="Path to GLM4 tokenizer")
|
| 228 |
+
parser.add_argument("--text_tokenizer_path", type=str, default="/nasdata/model/gpa", help="Path to text tokenizer")
|
| 229 |
+
parser.add_argument("--bicodec_tokenizer_path", type=str, default="/nasdata/model/gpa/BiCodec/", help="Path to BiCodec tokenizer")
|
| 230 |
+
parser.add_argument("--gpa_model_path", type=str, default="/nasdata/model/gpa", help="Path to GPA model")
|
| 231 |
+
|
| 232 |
+
# Audio inputs
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--ref_audio_path", type=str, default=None, help="Reference audio path"
|
| 235 |
+
)
|
| 236 |
+
parser.add_argument(
|
| 237 |
+
"--src_audio_path", type=str, default=None, help="Source audio path for VC/STT"
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Output
|
| 241 |
+
parser.add_argument("--output_dir", type=str, default=".", help="Directory to save output files")
|
| 242 |
+
|
| 243 |
+
# Device
|
| 244 |
+
default_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 245 |
+
parser.add_argument("--device", type=str, default=default_device, help="Device to use (e.g., cuda:0, cpu)")
|
| 246 |
+
|
| 247 |
+
# Task
|
| 248 |
+
parser.add_argument("--task", type=str, required=True, choices=["stt", "tts-a", "vc"], help="Task to run")
|
| 249 |
+
|
| 250 |
+
# TTS Inputs (Processor Arguments)
|
| 251 |
+
parser.add_argument("--text", type=str, default=None, help="Text for TTS")
|
| 252 |
+
|
| 253 |
+
return parser.parse_args()
|
| 254 |
+
|
| 255 |
+
def main():
|
| 256 |
+
args = parse_args()
|
| 257 |
+
|
| 258 |
+
# Ensure output directory exists
|
| 259 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 260 |
+
|
| 261 |
+
inference = GPAInference(
|
| 262 |
+
tokenizer_path=args.tokenizer_path,
|
| 263 |
+
text_tokenizer_path=args.text_tokenizer_path,
|
| 264 |
+
bicodec_tokenizer_path=args.bicodec_tokenizer_path,
|
| 265 |
+
gpa_model_path=args.gpa_model_path,
|
| 266 |
+
output_dir=args.output_dir,
|
| 267 |
+
device=args.device,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
if args.task == "stt":
|
| 271 |
+
if not args.src_audio_path:
|
| 272 |
+
raise ValueError("Error: --src_audio_path is required for STT task.")
|
| 273 |
+
# Pass gen_kwargs
|
| 274 |
+
result = inference.run_stt(audio_path=args.src_audio_path)
|
| 275 |
+
print("STT Result:", result)
|
| 276 |
+
|
| 277 |
+
elif args.task == "tts-a":
|
| 278 |
+
inference.run_tts(
|
| 279 |
+
task="tts-a",
|
| 280 |
+
output_filename="output_gpa_tts_a.wav",
|
| 281 |
+
text=args.text,
|
| 282 |
+
ref_audio_path=args.ref_audio_path,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
elif args.task == "vc":
|
| 286 |
+
inference.run_vc(
|
| 287 |
+
source_audio_path=args.src_audio_path,
|
| 288 |
+
ref_audio_path=args.ref_audio_path,
|
| 289 |
+
output_filename="output_gpa_vc.wav",
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
main()
|
models/__init__.py
ADDED
|
File without changes
|
models/bicodec_tokenizer/__init__.py
ADDED
|
File without changes
|
models/bicodec_tokenizer/base_model.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Time :2025/3/29 10:28
|
| 3 |
+
# Author :Hui Huang
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
from .tokenizer_utils import load_config
|
| 11 |
+
import os
|
| 12 |
+
from safetensors.torch import load_file
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SparkBaseModel(nn.Module):
|
| 16 |
+
@classmethod
|
| 17 |
+
def from_pretrained(cls, model_path: str):
|
| 18 |
+
config = load_config(os.path.join(model_path, "config.yaml"))['audio_tokenizer']
|
| 19 |
+
model = cls(config)
|
| 20 |
+
state_dict = load_file(os.path.join(model_path, "model.safetensors"))
|
| 21 |
+
model.load_state_dict(state_dict, strict=False)
|
| 22 |
+
model.eval()
|
| 23 |
+
model.remove_weight_norm()
|
| 24 |
+
return model
|
| 25 |
+
|
| 26 |
+
def remove_weight_norm(self):
|
| 27 |
+
"""Removes weight normalization from all layers."""
|
| 28 |
+
|
| 29 |
+
def _remove_weight_norm(m):
|
| 30 |
+
try:
|
| 31 |
+
torch.nn.utils.remove_weight_norm(m)
|
| 32 |
+
except ValueError:
|
| 33 |
+
pass # The module didn't have weight norm
|
| 34 |
+
|
| 35 |
+
self.apply(_remove_weight_norm)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SnacBaseModel(nn.Module):
|
| 39 |
+
@classmethod
|
| 40 |
+
def from_config(cls, config_path):
|
| 41 |
+
with open(config_path, "r") as f:
|
| 42 |
+
config = json.load(f)
|
| 43 |
+
model = cls(**config)
|
| 44 |
+
return model
|
| 45 |
+
|
| 46 |
+
@classmethod
|
| 47 |
+
def from_pretrained(cls, model_path: str):
|
| 48 |
+
model = cls.from_config(os.path.join(model_path, "config.json"))
|
| 49 |
+
state_dict = torch.load(
|
| 50 |
+
os.path.join(model_path, "pytorch_model.bin"),
|
| 51 |
+
map_location="cpu", weights_only=True)
|
| 52 |
+
model.load_state_dict(state_dict, strict=False)
|
| 53 |
+
model.eval()
|
| 54 |
+
return model
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MegaBaseModel(nn.Module):
|
| 58 |
+
CKPT_NAME = "model"
|
| 59 |
+
|
| 60 |
+
@classmethod
|
| 61 |
+
def from_pretrained(cls, model_path: str):
|
| 62 |
+
config_file = None
|
| 63 |
+
ckpt_path = None
|
| 64 |
+
for file in os.listdir(model_path):
|
| 65 |
+
if file.endswith(".ckpt"):
|
| 66 |
+
ckpt_path = os.path.join(model_path, file)
|
| 67 |
+
if file.endswith(".yaml"):
|
| 68 |
+
config_file = os.path.join(model_path, file)
|
| 69 |
+
if ckpt_path is None:
|
| 70 |
+
raise FileNotFoundError(f"No checkpoint found at {model_path}")
|
| 71 |
+
|
| 72 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
| 73 |
+
state_dict_all = {
|
| 74 |
+
k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()
|
| 75 |
+
}
|
| 76 |
+
state_dict = state_dict_all[cls.CKPT_NAME]
|
| 77 |
+
state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
| 78 |
+
|
| 79 |
+
if config_file is not None:
|
| 80 |
+
with open(config_file) as f:
|
| 81 |
+
config = yaml.safe_load(f)
|
| 82 |
+
model = cls(config)
|
| 83 |
+
else:
|
| 84 |
+
model = cls()
|
| 85 |
+
model.load_state_dict(state_dict, strict=False)
|
| 86 |
+
model.eval()
|
| 87 |
+
return model
|
models/bicodec_tokenizer/batch_processor.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Time :2024/11/17 15:33
|
| 3 |
+
# Author :Hui Huang
|
| 4 |
+
import asyncio
|
| 5 |
+
import uuid
|
| 6 |
+
from typing import Callable, List, Any, Awaitable, Tuple
|
| 7 |
+
from asyncio import Queue
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BatchProcessor:
|
| 11 |
+
"""Batch Processor for handling asynchronous requests in batches.
|
| 12 |
+
|
| 13 |
+
This class manages a queue of requests and processes them in batches
|
| 14 |
+
using multiple worker tasks.
|
| 15 |
+
|
| 16 |
+
Attributes:
|
| 17 |
+
processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
|
| 18 |
+
The function used for processing requests in batches.
|
| 19 |
+
num_workers (int): The number of worker tasks to process requests.
|
| 20 |
+
batch_size (int): The maximum number of requests to process in a single batch.
|
| 21 |
+
request_queue (Queue): The queue holding incoming requests.
|
| 22 |
+
loop (asyncio.AbstractEventLoop): The event loop used to create worker tasks.
|
| 23 |
+
worker_tasks (List[asyncio.Task]): The list of worker tasks.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
|
| 29 |
+
num_workers: int,
|
| 30 |
+
batch_size: int,
|
| 31 |
+
wait_timeout: float = 0.05
|
| 32 |
+
) -> None:
|
| 33 |
+
"""Initialize the BatchProcessor with the given processing function, number of workers, and batch size.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
|
| 37 |
+
The function used for processing requests in batches.
|
| 38 |
+
num_workers (int): The number of worker tasks to process requests.
|
| 39 |
+
batch_size (int): The maximum number of requests to process in a single batch.
|
| 40 |
+
"""
|
| 41 |
+
self.processing_function = processing_function
|
| 42 |
+
self.num_workers = num_workers
|
| 43 |
+
self.batch_size = batch_size
|
| 44 |
+
self.wait_timeout = wait_timeout
|
| 45 |
+
self.request_queue: Queue = Queue()
|
| 46 |
+
self.loop = asyncio.get_running_loop()
|
| 47 |
+
self.worker_tasks = [
|
| 48 |
+
self.loop.create_task(self.batch_processor(i)) for i in range(num_workers)
|
| 49 |
+
]
|
| 50 |
+
# Wait until all worker tasks are started
|
| 51 |
+
self.loop.create_task(self._log_workers_started())
|
| 52 |
+
|
| 53 |
+
async def _log_workers_started(self):
|
| 54 |
+
await asyncio.sleep(0) # Yield control to ensure workers have started
|
| 55 |
+
|
| 56 |
+
async def batch_processor(self, worker_id: int):
|
| 57 |
+
"""Worker task that processes requests from the queue in batches.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
worker_id (int): The identifier for the worker task.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
while True:
|
| 64 |
+
requests: List[Tuple[Any, asyncio.Future]] = []
|
| 65 |
+
try:
|
| 66 |
+
while len(requests) < self.batch_size:
|
| 67 |
+
request = await asyncio.wait_for(
|
| 68 |
+
self.request_queue.get(), timeout=self.wait_timeout
|
| 69 |
+
)
|
| 70 |
+
requests.append(request)
|
| 71 |
+
except asyncio.TimeoutError:
|
| 72 |
+
pass
|
| 73 |
+
|
| 74 |
+
if requests:
|
| 75 |
+
all_requests = [
|
| 76 |
+
req[0] for req in requests
|
| 77 |
+
] # Extract the actual input data from each request tuple
|
| 78 |
+
futures = [req[1] for req in requests] # Extract the futures to resolve
|
| 79 |
+
try:
|
| 80 |
+
results = await self.processing_function(all_requests)
|
| 81 |
+
|
| 82 |
+
for (future, result) in zip(futures, results):
|
| 83 |
+
future.set_result(result)
|
| 84 |
+
except Exception as e:
|
| 85 |
+
for future in futures:
|
| 86 |
+
future.set_exception(e)
|
| 87 |
+
|
| 88 |
+
async def add_request(self, single_input: Any):
|
| 89 |
+
"""Add a new request to the queue.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
single_input (Any): The input data for processing.
|
| 93 |
+
"""
|
| 94 |
+
# loop = asyncio.get_running_loop()
|
| 95 |
+
future = self.loop.create_future()
|
| 96 |
+
self.request_queue.put_nowait((single_input, future))
|
| 97 |
+
return future
|
| 98 |
+
|
| 99 |
+
async def shutdown(self):
|
| 100 |
+
"""Shutdown the batch processor by cancelling all worker tasks."""
|
| 101 |
+
for task in self.worker_tasks:
|
| 102 |
+
task.cancel()
|
| 103 |
+
try:
|
| 104 |
+
await task
|
| 105 |
+
except asyncio.CancelledError:
|
| 106 |
+
print("Worker task cancelled.")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class AsyncBatchEngine:
|
| 110 |
+
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
|
| 114 |
+
batch_size: int = 32,
|
| 115 |
+
wait_timeout: float = 0.01,
|
| 116 |
+
):
|
| 117 |
+
"""
|
| 118 |
+
Initialize the AsyncBatchEngine with a processing function, number of workers, and batch size.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
processing_function (Callable[[List[Any]], Awaitable[List[Any]]]): The batch processing function.
|
| 122 |
+
batch_size (int): The maximum number of requests to process in a single batch.
|
| 123 |
+
"""
|
| 124 |
+
self._processing_function = processing_function
|
| 125 |
+
self._batch_size = batch_size
|
| 126 |
+
self._is_running = False
|
| 127 |
+
self._batch_processor = None
|
| 128 |
+
self._wait_timeout = wait_timeout
|
| 129 |
+
|
| 130 |
+
async def start(self):
|
| 131 |
+
"""Start the engine by initializing the batch processor and worker tasks."""
|
| 132 |
+
if self._is_running:
|
| 133 |
+
return
|
| 134 |
+
|
| 135 |
+
self._batch_processor = BatchProcessor(
|
| 136 |
+
processing_function=self._processing_function,
|
| 137 |
+
batch_size=self._batch_size,
|
| 138 |
+
wait_timeout=self._wait_timeout,
|
| 139 |
+
num_workers=1
|
| 140 |
+
)
|
| 141 |
+
self._is_running = True
|
| 142 |
+
|
| 143 |
+
async def stop(self):
|
| 144 |
+
"""Stop the engine by shutting down the batch processor and worker tasks."""
|
| 145 |
+
self._check_running()
|
| 146 |
+
self._is_running = False
|
| 147 |
+
if self._batch_processor is not None:
|
| 148 |
+
await self._batch_processor.shutdown()
|
| 149 |
+
|
| 150 |
+
def _check_running(self):
|
| 151 |
+
"""Check if the engine is running.
|
| 152 |
+
|
| 153 |
+
Raises:
|
| 154 |
+
ValueError: If the engine is not running.
|
| 155 |
+
"""
|
| 156 |
+
if not self._is_running:
|
| 157 |
+
raise ValueError(
|
| 158 |
+
"The engine is not running. "
|
| 159 |
+
"You must start the engine before using it."
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
async def add_request(self, single_input: Any, request_id: str = None) -> dict:
|
| 163 |
+
"""Asynchronously add a request to be processed.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
single_input (Any): The input data for processing.
|
| 167 |
+
request_id (str): Optional request identifier to avoid data mix-up.
|
| 168 |
+
|
| 169 |
+
Raises:
|
| 170 |
+
ValueError: If the engine is not running when this method is called.
|
| 171 |
+
"""
|
| 172 |
+
if not self._is_running:
|
| 173 |
+
await self.start()
|
| 174 |
+
|
| 175 |
+
if request_id is None:
|
| 176 |
+
request_id = str(uuid.uuid4()) # Assign a unique ID if not provided
|
| 177 |
+
future = await self._batch_processor.add_request(single_input=single_input) # type: ignore
|
| 178 |
+
result = await future
|
| 179 |
+
return dict(
|
| 180 |
+
request_id=request_id,
|
| 181 |
+
feature=result
|
| 182 |
+
)
|
models/bicodec_tokenizer/models/__init__.py
ADDED
|
File without changes
|
models/bicodec_tokenizer/models/audio_tokenizer.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import sys
|
| 17 |
+
sys.path.append("../..")
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Any, Dict, Tuple
|
| 23 |
+
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
| 24 |
+
|
| 25 |
+
from arktts.models.sparktts.utils.file import load_config
|
| 26 |
+
from arktts.models.sparktts.utils.audio import load_audio
|
| 27 |
+
from arktts.models.sparktts.models.bicodec import BiCodec
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BiCodecTokenizer:
|
| 31 |
+
"""BiCodec tokenizer for handling audio input and tokenization."""
|
| 32 |
+
|
| 33 |
+
def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
|
| 34 |
+
super().__init__()
|
| 35 |
+
"""
|
| 36 |
+
Args:
|
| 37 |
+
model_dir: Path to the model directory.
|
| 38 |
+
device: Device to run the model on (default is GPU if available).
|
| 39 |
+
"""
|
| 40 |
+
self.device = device
|
| 41 |
+
self.model_dir = model_dir
|
| 42 |
+
self.config = load_config(f"{model_dir}/config.yaml")
|
| 43 |
+
self._initialize_model()
|
| 44 |
+
|
| 45 |
+
def _initialize_model(self):
|
| 46 |
+
"""Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
|
| 47 |
+
self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
|
| 48 |
+
self.device
|
| 49 |
+
)
|
| 50 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
| 51 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
| 52 |
+
)
|
| 53 |
+
self.feature_extractor = Wav2Vec2Model.from_pretrained(
|
| 54 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
| 55 |
+
).to(self.device)
|
| 56 |
+
self.feature_extractor.config.output_hidden_states = True
|
| 57 |
+
|
| 58 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
| 59 |
+
"""Get reference audio clip for speaker embedding."""
|
| 60 |
+
ref_segment_length = (
|
| 61 |
+
int(self.config["sample_rate"] * self.config["ref_segment_duration"])
|
| 62 |
+
// self.config["latent_hop_length"]
|
| 63 |
+
* self.config["latent_hop_length"]
|
| 64 |
+
)
|
| 65 |
+
wav_length = len(wav)
|
| 66 |
+
|
| 67 |
+
if ref_segment_length > wav_length:
|
| 68 |
+
# Repeat and truncate to handle insufficient length
|
| 69 |
+
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
| 70 |
+
|
| 71 |
+
return wav[:ref_segment_length]
|
| 72 |
+
|
| 73 |
+
def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
|
| 74 |
+
"""load auido and get reference audio from wav path"""
|
| 75 |
+
wav = load_audio(
|
| 76 |
+
wav_path,
|
| 77 |
+
sampling_rate=self.config["sample_rate"],
|
| 78 |
+
volume_normalize=self.config["volume_normalize"],
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
wav_ref = self.get_ref_clip(wav)
|
| 82 |
+
|
| 83 |
+
wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
|
| 84 |
+
return wav, wav_ref
|
| 85 |
+
|
| 86 |
+
def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
"""extract wav2vec2 features"""
|
| 88 |
+
inputs = self.processor(
|
| 89 |
+
wavs,
|
| 90 |
+
sampling_rate=16000,
|
| 91 |
+
return_tensors="pt",
|
| 92 |
+
padding=True,
|
| 93 |
+
output_hidden_states=True,
|
| 94 |
+
).input_values
|
| 95 |
+
feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
|
| 96 |
+
feats_mix = (
|
| 97 |
+
feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
|
| 98 |
+
) / 3
|
| 99 |
+
|
| 100 |
+
return feats_mix
|
| 101 |
+
|
| 102 |
+
def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
|
| 103 |
+
"""tokenize the batch of audio
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
batch:
|
| 107 |
+
wavs (List[np.ndarray]): batch of audio
|
| 108 |
+
ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
|
| 112 |
+
global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
|
| 113 |
+
"""
|
| 114 |
+
feats = self.extract_wav2vec2_features(batch["wav"])
|
| 115 |
+
batch["feat"] = feats
|
| 116 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
| 117 |
+
|
| 118 |
+
return global_tokens, semantic_tokens
|
| 119 |
+
|
| 120 |
+
def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 121 |
+
"""tokenize the audio"""
|
| 122 |
+
wav, ref_wav = self.process_audio(audio_path)
|
| 123 |
+
feat = self.extract_wav2vec2_features(wav)
|
| 124 |
+
batch = {
|
| 125 |
+
"wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
|
| 126 |
+
"ref_wav": ref_wav.to(self.device),
|
| 127 |
+
"feat": feat.to(self.device),
|
| 128 |
+
}
|
| 129 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
| 130 |
+
|
| 131 |
+
return global_tokens, semantic_tokens
|
| 132 |
+
|
| 133 |
+
def detokenize(
|
| 134 |
+
self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
|
| 135 |
+
) -> np.array:
|
| 136 |
+
"""detokenize the tokens to waveform
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
global_tokens: global tokens. shape: (batch_size, global_dim)
|
| 140 |
+
semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
|
| 144 |
+
"""
|
| 145 |
+
global_tokens = global_tokens.unsqueeze(1)
|
| 146 |
+
wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
|
| 147 |
+
return wav_rec.detach().squeeze().cpu().numpy()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# test
|
| 151 |
+
if __name__ == "__main__":
|
| 152 |
+
import soundfile as sf
|
| 153 |
+
|
| 154 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 155 |
+
tokenizer = BiCodecTokenizer(
|
| 156 |
+
model_dir="pretrained_models/Spark-TTS-0.5B",
|
| 157 |
+
device=device,
|
| 158 |
+
)
|
| 159 |
+
wav_path = "example/prompt_audio.wav"
|
| 160 |
+
|
| 161 |
+
global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
|
| 162 |
+
|
| 163 |
+
wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
|
| 164 |
+
sf.write("example/prompt_recon.wav", wav_rec, 16000)
|
models/bicodec_tokenizer/models/bicodec.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import sys
|
| 16 |
+
sys.path.append("../..")
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Dict, Any
|
| 21 |
+
from omegaconf import DictConfig
|
| 22 |
+
from safetensors.torch import load_file
|
| 23 |
+
|
| 24 |
+
from ..utils.file import load_config
|
| 25 |
+
from ..modules.speaker.speaker_encoder import SpeakerEncoder
|
| 26 |
+
from ..modules.encoder_decoder.feat_encoder import Encoder
|
| 27 |
+
from ..modules.encoder_decoder.feat_decoder import Decoder
|
| 28 |
+
from ..modules.encoder_decoder.wave_generator import WaveGenerator
|
| 29 |
+
from ..modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BiCodec(nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
|
| 35 |
+
quantizer, and wave generator.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
mel_params: Dict[str, Any],
|
| 41 |
+
encoder: nn.Module,
|
| 42 |
+
decoder: nn.Module,
|
| 43 |
+
quantizer: nn.Module,
|
| 44 |
+
speaker_encoder: nn.Module,
|
| 45 |
+
prenet: nn.Module,
|
| 46 |
+
postnet: nn.Module,
|
| 47 |
+
**kwargs
|
| 48 |
+
) -> None:
|
| 49 |
+
"""
|
| 50 |
+
Initializes the BiCodec model with the required components.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
mel_params (dict): Parameters for the mel-spectrogram transformer.
|
| 54 |
+
encoder (nn.Module): Encoder module.
|
| 55 |
+
decoder (nn.Module): Decoder module.
|
| 56 |
+
quantizer (nn.Module): Quantizer module.
|
| 57 |
+
speaker_encoder (nn.Module): Speaker encoder module.
|
| 58 |
+
prenet (nn.Module): Prenet network.
|
| 59 |
+
postnet (nn.Module): Postnet network.
|
| 60 |
+
"""
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.encoder = encoder
|
| 63 |
+
self.decoder = decoder
|
| 64 |
+
self.quantizer = quantizer
|
| 65 |
+
self.speaker_encoder = speaker_encoder
|
| 66 |
+
self.prenet = prenet
|
| 67 |
+
self.postnet = postnet
|
| 68 |
+
self.init_mel_transformer(mel_params)
|
| 69 |
+
|
| 70 |
+
@classmethod
|
| 71 |
+
def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
|
| 72 |
+
"""
|
| 73 |
+
Loads the model from a checkpoint.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
model_dir (Path): Path to the model directory containing checkpoint and config.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
BiCodec: The initialized BiCodec model.
|
| 80 |
+
"""
|
| 81 |
+
ckpt_path = f'{model_dir}/model.safetensors'
|
| 82 |
+
config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
|
| 83 |
+
mel_params = config["mel_params"]
|
| 84 |
+
encoder = Encoder(**config["encoder"])
|
| 85 |
+
quantizer = FactorizedVectorQuantize(**config["quantizer"])
|
| 86 |
+
prenet = Decoder(**config["prenet"])
|
| 87 |
+
postnet = Decoder(**config["postnet"])
|
| 88 |
+
decoder = WaveGenerator(**config["decoder"])
|
| 89 |
+
speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
|
| 90 |
+
|
| 91 |
+
model = cls(
|
| 92 |
+
mel_params=mel_params,
|
| 93 |
+
encoder=encoder,
|
| 94 |
+
decoder=decoder,
|
| 95 |
+
quantizer=quantizer,
|
| 96 |
+
speaker_encoder=speaker_encoder,
|
| 97 |
+
prenet=prenet,
|
| 98 |
+
postnet=postnet,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
state_dict = load_file(ckpt_path)
|
| 102 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
| 103 |
+
|
| 104 |
+
for key in missing_keys:
|
| 105 |
+
print(f"Missing tensor: {key}")
|
| 106 |
+
for key in unexpected_keys:
|
| 107 |
+
print(f"Unexpected tensor: {key}")
|
| 108 |
+
|
| 109 |
+
model.eval()
|
| 110 |
+
model.remove_weight_norm()
|
| 111 |
+
|
| 112 |
+
return model
|
| 113 |
+
|
| 114 |
+
def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 115 |
+
"""
|
| 116 |
+
Performs a forward pass through the model.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
batch (dict): A dictionary containing features, reference waveform, and target waveform.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
dict: A dictionary containing the reconstruction, features, and other metrics.
|
| 123 |
+
"""
|
| 124 |
+
feat = batch["feat"]
|
| 125 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
| 126 |
+
|
| 127 |
+
z = self.encoder(feat.transpose(1, 2))
|
| 128 |
+
vq_outputs = self.quantizer(z)
|
| 129 |
+
|
| 130 |
+
x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
|
| 131 |
+
|
| 132 |
+
conditions = d_vector
|
| 133 |
+
with_speaker_loss = False
|
| 134 |
+
|
| 135 |
+
x = self.prenet(vq_outputs["z_q"], conditions)
|
| 136 |
+
pred_feat = self.postnet(x)
|
| 137 |
+
x = x + conditions.unsqueeze(-1)
|
| 138 |
+
wav_recon = self.decoder(x)
|
| 139 |
+
|
| 140 |
+
return {
|
| 141 |
+
"vq_loss": vq_outputs["vq_loss"],
|
| 142 |
+
"perplexity": vq_outputs["perplexity"],
|
| 143 |
+
"cluster_size": vq_outputs["active_num"],
|
| 144 |
+
"recons": wav_recon,
|
| 145 |
+
"pred_feat": pred_feat,
|
| 146 |
+
"x_vector": x_vector,
|
| 147 |
+
"d_vector": d_vector,
|
| 148 |
+
"audios": batch["wav"].unsqueeze(1),
|
| 149 |
+
"with_speaker_loss": with_speaker_loss,
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
@torch.no_grad()
|
| 153 |
+
def tokenize(self, batch: Dict[str, Any]):
|
| 154 |
+
"""
|
| 155 |
+
Tokenizes the input audio into semantic and global tokens.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
batch (dict): The input audio features and reference waveform.
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
tuple: Semantic tokens and global tokens.
|
| 162 |
+
"""
|
| 163 |
+
feat = batch["feat"]
|
| 164 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
| 165 |
+
|
| 166 |
+
z = self.encoder(feat.transpose(1, 2))
|
| 167 |
+
semantic_tokens = self.quantizer.tokenize(z)
|
| 168 |
+
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
|
| 169 |
+
|
| 170 |
+
return semantic_tokens, global_tokens
|
| 171 |
+
|
| 172 |
+
@torch.no_grad()
|
| 173 |
+
def detokenize(self, semantic_tokens, global_tokens):
|
| 174 |
+
"""
|
| 175 |
+
Detokenizes the semantic and global tokens into a waveform.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
semantic_tokens (tensor): Semantic tokens.
|
| 179 |
+
global_tokens (tensor): Global tokens.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
tensor: Reconstructed waveform.
|
| 183 |
+
"""
|
| 184 |
+
z_q = self.quantizer.detokenize(semantic_tokens)
|
| 185 |
+
d_vector = self.speaker_encoder.detokenize(global_tokens)
|
| 186 |
+
x = self.prenet(z_q, d_vector)
|
| 187 |
+
x = x + d_vector.unsqueeze(-1)
|
| 188 |
+
wav_recon = self.decoder(x)
|
| 189 |
+
|
| 190 |
+
return wav_recon
|
| 191 |
+
|
| 192 |
+
def init_mel_transformer(self, config: Dict[str, Any]):
|
| 193 |
+
"""
|
| 194 |
+
Initializes the MelSpectrogram transformer based on the provided configuration.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
config (dict): Configuration parameters for MelSpectrogram.
|
| 198 |
+
"""
|
| 199 |
+
import torchaudio.transforms as TT
|
| 200 |
+
|
| 201 |
+
self.mel_transformer = TT.MelSpectrogram(
|
| 202 |
+
config["sample_rate"],
|
| 203 |
+
config["n_fft"],
|
| 204 |
+
config["win_length"],
|
| 205 |
+
config["hop_length"],
|
| 206 |
+
config["mel_fmin"],
|
| 207 |
+
config["mel_fmax"],
|
| 208 |
+
n_mels=config["num_mels"],
|
| 209 |
+
power=1,
|
| 210 |
+
norm="slaney",
|
| 211 |
+
mel_scale="slaney",
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
def remove_weight_norm(self):
|
| 215 |
+
"""Removes weight normalization from all layers."""
|
| 216 |
+
def _remove_weight_norm(m):
|
| 217 |
+
try:
|
| 218 |
+
torch.nn.utils.remove_weight_norm(m)
|
| 219 |
+
except ValueError:
|
| 220 |
+
pass # The module didn't have weight norm
|
| 221 |
+
|
| 222 |
+
self.apply(_remove_weight_norm)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# Test the model
|
| 226 |
+
if __name__ == "__main__":
|
| 227 |
+
|
| 228 |
+
config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
|
| 229 |
+
model = BiCodec.load_from_checkpoint(
|
| 230 |
+
model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# Generate random inputs for testing
|
| 234 |
+
duration = 0.96
|
| 235 |
+
x = torch.randn(20, 1, int(duration * 16000))
|
| 236 |
+
feat = torch.randn(20, int(duration * 50), 1024)
|
| 237 |
+
inputs = {"feat": feat, "wav": x, "ref_wav": x}
|
| 238 |
+
|
| 239 |
+
# Forward pass
|
| 240 |
+
outputs = model(inputs)
|
| 241 |
+
semantic_tokens, global_tokens = model.tokenize(inputs)
|
| 242 |
+
wav_recon = model.detokenize(semantic_tokens, global_tokens)
|
| 243 |
+
|
| 244 |
+
# Verify if the reconstruction matches
|
| 245 |
+
if torch.allclose(outputs["recons"].detach(), wav_recon):
|
| 246 |
+
print("Test successful")
|
| 247 |
+
else:
|
| 248 |
+
print("Test failed")
|
models/bicodec_tokenizer/modules/blocks/layers.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.nn.utils import weight_norm
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def WNConv1d(*args, **kwargs):
|
| 25 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def WNConvTranspose1d(*args, **kwargs):
|
| 29 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Scripting this brings model speed up 1.4x
|
| 33 |
+
@torch.jit.script
|
| 34 |
+
def snake(x, alpha):
|
| 35 |
+
shape = x.shape
|
| 36 |
+
x = x.reshape(shape[0], shape[1], -1)
|
| 37 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
| 38 |
+
x = x.reshape(shape)
|
| 39 |
+
return x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Snake1d(nn.Module):
|
| 43 |
+
def __init__(self, channels):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
| 46 |
+
|
| 47 |
+
def forward(self, x):
|
| 48 |
+
return snake(x, self.alpha)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ResidualUnit(nn.Module):
|
| 52 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
| 53 |
+
super().__init__()
|
| 54 |
+
pad = ((7 - 1) * dilation) // 2
|
| 55 |
+
self.block = nn.Sequential(
|
| 56 |
+
Snake1d(dim),
|
| 57 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
| 58 |
+
Snake1d(dim),
|
| 59 |
+
WNConv1d(dim, dim, kernel_size=1),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
y = self.block(x)
|
| 64 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
| 65 |
+
if pad > 0:
|
| 66 |
+
x = x[..., pad:-pad]
|
| 67 |
+
return x + y
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def init_weights(m):
|
| 71 |
+
if isinstance(m, nn.Conv1d):
|
| 72 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 73 |
+
nn.init.constant_(m.bias, 0)
|
models/bicodec_tokenizer/modules/blocks/samper.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SamplingBlock(nn.Module):
|
| 23 |
+
"""Sampling block for upsampling or downsampling"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
dim: int,
|
| 28 |
+
groups: int = 1,
|
| 29 |
+
upsample_scale: int = 1,
|
| 30 |
+
downsample_scale: int = 1,
|
| 31 |
+
) -> None:
|
| 32 |
+
"""
|
| 33 |
+
Args:
|
| 34 |
+
dim: input dimension
|
| 35 |
+
groups: number of groups
|
| 36 |
+
upsample_scale: upsampling scale
|
| 37 |
+
downsample_scale: downsampling scale
|
| 38 |
+
"""
|
| 39 |
+
super(SamplingBlock, self).__init__()
|
| 40 |
+
|
| 41 |
+
self.upsample_scale = upsample_scale
|
| 42 |
+
self.downsample_scale = downsample_scale
|
| 43 |
+
|
| 44 |
+
if self.upsample_scale > 1:
|
| 45 |
+
self.de_conv_upsampler = nn.Sequential(
|
| 46 |
+
nn.LeakyReLU(0.2),
|
| 47 |
+
nn.ConvTranspose1d(
|
| 48 |
+
dim,
|
| 49 |
+
dim,
|
| 50 |
+
kernel_size=upsample_scale * 2,
|
| 51 |
+
stride=upsample_scale,
|
| 52 |
+
padding=upsample_scale // 2 + upsample_scale % 2,
|
| 53 |
+
output_padding=upsample_scale % 2,
|
| 54 |
+
groups=groups,
|
| 55 |
+
),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
if self.downsample_scale > 1:
|
| 59 |
+
self.conv_downsampler = nn.Sequential(
|
| 60 |
+
nn.LeakyReLU(0.2),
|
| 61 |
+
nn.Conv1d(
|
| 62 |
+
dim,
|
| 63 |
+
dim,
|
| 64 |
+
kernel_size=2 * downsample_scale,
|
| 65 |
+
stride=downsample_scale,
|
| 66 |
+
padding=downsample_scale // 2 + downsample_scale % 2,
|
| 67 |
+
groups=groups,
|
| 68 |
+
),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def repeat_upsampler(x, upsample_scale):
|
| 73 |
+
return x.repeat_interleave(upsample_scale, dim=2)
|
| 74 |
+
|
| 75 |
+
@staticmethod
|
| 76 |
+
def skip_downsampler(x, downsample_scale):
|
| 77 |
+
return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
x = x.transpose(1, 2)
|
| 81 |
+
if self.upsample_scale > 1:
|
| 82 |
+
repeat_res = self.repeat_upsampler(x, self.upsample_scale)
|
| 83 |
+
deconv_res = self.de_conv_upsampler(x)
|
| 84 |
+
upmerge_res = repeat_res + deconv_res
|
| 85 |
+
else:
|
| 86 |
+
upmerge_res = x
|
| 87 |
+
repeat_res = x
|
| 88 |
+
|
| 89 |
+
if self.downsample_scale > 1:
|
| 90 |
+
conv_res = self.conv_downsampler(upmerge_res)
|
| 91 |
+
skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
|
| 92 |
+
skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
|
| 93 |
+
else:
|
| 94 |
+
conv_res = upmerge_res
|
| 95 |
+
skip2_res = upmerge_res
|
| 96 |
+
skip1_res = repeat_res
|
| 97 |
+
|
| 98 |
+
final_res = conv_res + skip1_res + skip2_res
|
| 99 |
+
|
| 100 |
+
return final_res
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# test
|
| 104 |
+
if __name__ == "__main__":
|
| 105 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
| 106 |
+
model = SamplingBlock(1024, 1024, upsample_scale=2)
|
| 107 |
+
model_down = SamplingBlock(1024, 1024, downsample_scale=2)
|
| 108 |
+
output = model(test_input)
|
| 109 |
+
output_down = model_down(test_input)
|
| 110 |
+
print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
|
| 111 |
+
print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
|
| 112 |
+
if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
|
| 113 |
+
[8, 1024, 25]
|
| 114 |
+
):
|
| 115 |
+
print("test successful")
|
models/bicodec_tokenizer/modules/blocks/vocos.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from typing import Tuple
|
| 21 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
| 22 |
+
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ConvNeXtBlock(nn.Module):
|
| 27 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dim (int): Number of input channels.
|
| 31 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
| 32 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 33 |
+
Defaults to None.
|
| 34 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
| 35 |
+
None means non-conditional LayerNorm. Defaults to None.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
dim: int,
|
| 41 |
+
intermediate_dim: int,
|
| 42 |
+
layer_scale_init_value: float,
|
| 43 |
+
condition_dim: Optional[int] = None,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.dwconv = nn.Conv1d(
|
| 47 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
| 48 |
+
) # depthwise conv
|
| 49 |
+
self.adanorm = condition_dim is not None
|
| 50 |
+
if condition_dim:
|
| 51 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
| 52 |
+
else:
|
| 53 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 54 |
+
self.pwconv1 = nn.Linear(
|
| 55 |
+
dim, intermediate_dim
|
| 56 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 57 |
+
self.act = nn.GELU()
|
| 58 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
| 59 |
+
self.gamma = (
|
| 60 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
| 61 |
+
if layer_scale_init_value > 0
|
| 62 |
+
else None
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def forward(
|
| 66 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
| 67 |
+
) -> torch.Tensor:
|
| 68 |
+
residual = x
|
| 69 |
+
x = self.dwconv(x)
|
| 70 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
| 71 |
+
if self.adanorm:
|
| 72 |
+
assert cond_embedding_id is not None
|
| 73 |
+
x = self.norm(x, cond_embedding_id)
|
| 74 |
+
else:
|
| 75 |
+
x = self.norm(x)
|
| 76 |
+
x = self.pwconv1(x)
|
| 77 |
+
x = self.act(x)
|
| 78 |
+
x = self.pwconv2(x)
|
| 79 |
+
if self.gamma is not None:
|
| 80 |
+
x = self.gamma * x
|
| 81 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
| 82 |
+
|
| 83 |
+
x = residual + x
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class AdaLayerNorm(nn.Module):
|
| 88 |
+
"""
|
| 89 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
condition_dim (int): Dimension of the condition.
|
| 93 |
+
embedding_dim (int): Dimension of the embeddings.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.eps = eps
|
| 99 |
+
self.dim = embedding_dim
|
| 100 |
+
self.scale = nn.Linear(condition_dim, embedding_dim)
|
| 101 |
+
self.shift = nn.Linear(condition_dim, embedding_dim)
|
| 102 |
+
torch.nn.init.ones_(self.scale.weight)
|
| 103 |
+
torch.nn.init.zeros_(self.shift.weight)
|
| 104 |
+
|
| 105 |
+
def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
|
| 106 |
+
scale = self.scale(cond_embedding)
|
| 107 |
+
shift = self.shift(cond_embedding)
|
| 108 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
| 109 |
+
x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class ResBlock1(nn.Module):
|
| 114 |
+
"""
|
| 115 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
| 116 |
+
but without upsampling layers.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
dim (int): Number of input channels.
|
| 120 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
| 121 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
| 122 |
+
Defaults to (1, 3, 5).
|
| 123 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
| 124 |
+
Defaults to 0.1.
|
| 125 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 126 |
+
Defaults to None.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
dim: int,
|
| 132 |
+
kernel_size: int = 3,
|
| 133 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
| 134 |
+
lrelu_slope: float = 0.1,
|
| 135 |
+
layer_scale_init_value: Optional[float] = None,
|
| 136 |
+
):
|
| 137 |
+
super().__init__()
|
| 138 |
+
self.lrelu_slope = lrelu_slope
|
| 139 |
+
self.convs1 = nn.ModuleList(
|
| 140 |
+
[
|
| 141 |
+
weight_norm(
|
| 142 |
+
nn.Conv1d(
|
| 143 |
+
dim,
|
| 144 |
+
dim,
|
| 145 |
+
kernel_size,
|
| 146 |
+
1,
|
| 147 |
+
dilation=dilation[0],
|
| 148 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
| 149 |
+
)
|
| 150 |
+
),
|
| 151 |
+
weight_norm(
|
| 152 |
+
nn.Conv1d(
|
| 153 |
+
dim,
|
| 154 |
+
dim,
|
| 155 |
+
kernel_size,
|
| 156 |
+
1,
|
| 157 |
+
dilation=dilation[1],
|
| 158 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
| 159 |
+
)
|
| 160 |
+
),
|
| 161 |
+
weight_norm(
|
| 162 |
+
nn.Conv1d(
|
| 163 |
+
dim,
|
| 164 |
+
dim,
|
| 165 |
+
kernel_size,
|
| 166 |
+
1,
|
| 167 |
+
dilation=dilation[2],
|
| 168 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
| 169 |
+
)
|
| 170 |
+
),
|
| 171 |
+
]
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
self.convs2 = nn.ModuleList(
|
| 175 |
+
[
|
| 176 |
+
weight_norm(
|
| 177 |
+
nn.Conv1d(
|
| 178 |
+
dim,
|
| 179 |
+
dim,
|
| 180 |
+
kernel_size,
|
| 181 |
+
1,
|
| 182 |
+
dilation=1,
|
| 183 |
+
padding=self.get_padding(kernel_size, 1),
|
| 184 |
+
)
|
| 185 |
+
),
|
| 186 |
+
weight_norm(
|
| 187 |
+
nn.Conv1d(
|
| 188 |
+
dim,
|
| 189 |
+
dim,
|
| 190 |
+
kernel_size,
|
| 191 |
+
1,
|
| 192 |
+
dilation=1,
|
| 193 |
+
padding=self.get_padding(kernel_size, 1),
|
| 194 |
+
)
|
| 195 |
+
),
|
| 196 |
+
weight_norm(
|
| 197 |
+
nn.Conv1d(
|
| 198 |
+
dim,
|
| 199 |
+
dim,
|
| 200 |
+
kernel_size,
|
| 201 |
+
1,
|
| 202 |
+
dilation=1,
|
| 203 |
+
padding=self.get_padding(kernel_size, 1),
|
| 204 |
+
)
|
| 205 |
+
),
|
| 206 |
+
]
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self.gamma = nn.ParameterList(
|
| 210 |
+
[
|
| 211 |
+
(
|
| 212 |
+
nn.Parameter(
|
| 213 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
| 214 |
+
)
|
| 215 |
+
if layer_scale_init_value is not None
|
| 216 |
+
else None
|
| 217 |
+
),
|
| 218 |
+
(
|
| 219 |
+
nn.Parameter(
|
| 220 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
| 221 |
+
)
|
| 222 |
+
if layer_scale_init_value is not None
|
| 223 |
+
else None
|
| 224 |
+
),
|
| 225 |
+
(
|
| 226 |
+
nn.Parameter(
|
| 227 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
| 228 |
+
)
|
| 229 |
+
if layer_scale_init_value is not None
|
| 230 |
+
else None
|
| 231 |
+
),
|
| 232 |
+
]
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 236 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
| 237 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
| 238 |
+
xt = c1(xt)
|
| 239 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
| 240 |
+
xt = c2(xt)
|
| 241 |
+
if gamma is not None:
|
| 242 |
+
xt = gamma * xt
|
| 243 |
+
x = xt + x
|
| 244 |
+
return x
|
| 245 |
+
|
| 246 |
+
def remove_weight_norm(self):
|
| 247 |
+
for l in self.convs1:
|
| 248 |
+
remove_weight_norm(l)
|
| 249 |
+
for l in self.convs2:
|
| 250 |
+
remove_weight_norm(l)
|
| 251 |
+
|
| 252 |
+
@staticmethod
|
| 253 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
| 254 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class Backbone(nn.Module):
|
| 258 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
| 259 |
+
|
| 260 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 261 |
+
"""
|
| 262 |
+
Args:
|
| 263 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
| 264 |
+
C denotes output features, and L is the sequence length.
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
| 268 |
+
and H denotes the model dimension.
|
| 269 |
+
"""
|
| 270 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class VocosBackbone(Backbone):
|
| 274 |
+
"""
|
| 275 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
input_channels (int): Number of input features channels.
|
| 279 |
+
dim (int): Hidden dimension of the model.
|
| 280 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
| 281 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
| 282 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
| 283 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
| 284 |
+
None means non-conditional model. Defaults to None.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
def __init__(
|
| 288 |
+
self,
|
| 289 |
+
input_channels: int,
|
| 290 |
+
dim: int,
|
| 291 |
+
intermediate_dim: int,
|
| 292 |
+
num_layers: int,
|
| 293 |
+
layer_scale_init_value: Optional[float] = None,
|
| 294 |
+
condition_dim: Optional[int] = None,
|
| 295 |
+
):
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.input_channels = input_channels
|
| 298 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
| 299 |
+
self.adanorm = condition_dim is not None
|
| 300 |
+
if condition_dim:
|
| 301 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
| 302 |
+
else:
|
| 303 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 304 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
| 305 |
+
self.convnext = nn.ModuleList(
|
| 306 |
+
[
|
| 307 |
+
ConvNeXtBlock(
|
| 308 |
+
dim=dim,
|
| 309 |
+
intermediate_dim=intermediate_dim,
|
| 310 |
+
layer_scale_init_value=layer_scale_init_value,
|
| 311 |
+
condition_dim=condition_dim,
|
| 312 |
+
)
|
| 313 |
+
for _ in range(num_layers)
|
| 314 |
+
]
|
| 315 |
+
)
|
| 316 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
| 317 |
+
self.apply(self._init_weights)
|
| 318 |
+
|
| 319 |
+
def _init_weights(self, m):
|
| 320 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 321 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 322 |
+
nn.init.constant_(m.bias, 0)
|
| 323 |
+
|
| 324 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
|
| 325 |
+
x = self.embed(x)
|
| 326 |
+
if self.adanorm:
|
| 327 |
+
assert condition is not None
|
| 328 |
+
x = self.norm(x.transpose(1, 2), condition)
|
| 329 |
+
else:
|
| 330 |
+
x = self.norm(x.transpose(1, 2))
|
| 331 |
+
x = x.transpose(1, 2)
|
| 332 |
+
for conv_block in self.convnext:
|
| 333 |
+
x = conv_block(x, condition)
|
| 334 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
| 335 |
+
return x
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class VocosResNetBackbone(Backbone):
|
| 339 |
+
"""
|
| 340 |
+
Vocos backbone module built with ResBlocks.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
input_channels (int): Number of input features channels.
|
| 344 |
+
dim (int): Hidden dimension of the model.
|
| 345 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
| 346 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
def __init__(
|
| 350 |
+
self,
|
| 351 |
+
input_channels,
|
| 352 |
+
dim,
|
| 353 |
+
num_blocks,
|
| 354 |
+
layer_scale_init_value=None,
|
| 355 |
+
):
|
| 356 |
+
super().__init__()
|
| 357 |
+
self.input_channels = input_channels
|
| 358 |
+
self.embed = weight_norm(
|
| 359 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
| 360 |
+
)
|
| 361 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
| 362 |
+
self.resnet = nn.Sequential(
|
| 363 |
+
*[
|
| 364 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
| 365 |
+
for _ in range(num_blocks)
|
| 366 |
+
]
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 370 |
+
x = self.embed(x)
|
| 371 |
+
x = self.resnet(x)
|
| 372 |
+
x = x.transpose(1, 2)
|
| 373 |
+
return x
|
models/bicodec_tokenizer/modules/encoder_decoder/feat_decoder.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from typing import List
|
| 21 |
+
|
| 22 |
+
from ..blocks.vocos import VocosBackbone
|
| 23 |
+
from ..blocks.samper import SamplingBlock
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Decoder(nn.Module):
|
| 27 |
+
"""Decoder module with convnext and upsampling blocks
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
sample_ratios (List[int]): sample ratios
|
| 31 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
input_channels: int,
|
| 37 |
+
vocos_dim: int,
|
| 38 |
+
vocos_intermediate_dim: int,
|
| 39 |
+
vocos_num_layers: int,
|
| 40 |
+
out_channels: int,
|
| 41 |
+
condition_dim: int = None,
|
| 42 |
+
sample_ratios: List[int] = [1, 1],
|
| 43 |
+
use_tanh_at_final: bool = False,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
|
| 47 |
+
self.linear_pre = nn.Linear(input_channels, vocos_dim)
|
| 48 |
+
modules = [
|
| 49 |
+
nn.Sequential(
|
| 50 |
+
SamplingBlock(
|
| 51 |
+
dim=vocos_dim,
|
| 52 |
+
groups=vocos_dim,
|
| 53 |
+
upsample_scale=ratio,
|
| 54 |
+
),
|
| 55 |
+
VocosBackbone(
|
| 56 |
+
input_channels=vocos_dim,
|
| 57 |
+
dim=vocos_dim,
|
| 58 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 59 |
+
num_layers=2,
|
| 60 |
+
condition_dim=None,
|
| 61 |
+
),
|
| 62 |
+
)
|
| 63 |
+
for ratio in sample_ratios
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
self.downsample = nn.Sequential(*modules)
|
| 67 |
+
|
| 68 |
+
self.vocos_backbone = VocosBackbone(
|
| 69 |
+
input_channels=vocos_dim,
|
| 70 |
+
dim=vocos_dim,
|
| 71 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 72 |
+
num_layers=vocos_num_layers,
|
| 73 |
+
condition_dim=condition_dim,
|
| 74 |
+
)
|
| 75 |
+
self.linear = nn.Linear(vocos_dim, out_channels)
|
| 76 |
+
self.use_tanh_at_final = use_tanh_at_final
|
| 77 |
+
|
| 78 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor = None):
|
| 79 |
+
"""encoder forward.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
| 86 |
+
"""
|
| 87 |
+
x = self.linear_pre(x.transpose(1, 2))
|
| 88 |
+
x = self.downsample(x).transpose(1, 2)
|
| 89 |
+
x = self.vocos_backbone(x, condition=c)
|
| 90 |
+
x = self.linear(x).transpose(1, 2)
|
| 91 |
+
if self.use_tanh_at_final:
|
| 92 |
+
x = torch.tanh(x)
|
| 93 |
+
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# test
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
| 100 |
+
condition = torch.randn(8, 256)
|
| 101 |
+
decoder = Decoder(
|
| 102 |
+
input_channels=1024,
|
| 103 |
+
vocos_dim=384,
|
| 104 |
+
vocos_intermediate_dim=2048,
|
| 105 |
+
vocos_num_layers=12,
|
| 106 |
+
out_channels=256,
|
| 107 |
+
condition_dim=256,
|
| 108 |
+
sample_ratios=[2, 2],
|
| 109 |
+
)
|
| 110 |
+
output = decoder(test_input, condition)
|
| 111 |
+
print(output.shape) # torch.Size([8, 256, 200])
|
| 112 |
+
if output.shape == torch.Size([8, 256, 200]):
|
| 113 |
+
print("Decoder test passed")
|
| 114 |
+
else:
|
| 115 |
+
print("Decoder test failed")
|
models/bicodec_tokenizer/modules/encoder_decoder/feat_encoder.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from typing import List
|
| 21 |
+
import sys
|
| 22 |
+
sys.path.append("../../../..")
|
| 23 |
+
sys.path.append("../../../../..")
|
| 24 |
+
from ..blocks.vocos import VocosBackbone
|
| 25 |
+
from ..blocks.samper import SamplingBlock
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Encoder(nn.Module):
|
| 29 |
+
"""Encoder module with convnext and downsampling blocks"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
input_channels: int,
|
| 34 |
+
vocos_dim: int,
|
| 35 |
+
vocos_intermediate_dim: int,
|
| 36 |
+
vocos_num_layers: int,
|
| 37 |
+
out_channels: int,
|
| 38 |
+
sample_ratios: List[int] = [1, 1],
|
| 39 |
+
):
|
| 40 |
+
super().__init__()
|
| 41 |
+
"""
|
| 42 |
+
Encoder module with VocosBackbone and sampling blocks.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
sample_ratios (List[int]): sample ratios
|
| 46 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
| 47 |
+
"""
|
| 48 |
+
self.encoder = VocosBackbone(
|
| 49 |
+
input_channels=input_channels,
|
| 50 |
+
dim=vocos_dim,
|
| 51 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 52 |
+
num_layers=vocos_num_layers,
|
| 53 |
+
condition_dim=None,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
modules = [
|
| 57 |
+
nn.Sequential(
|
| 58 |
+
SamplingBlock(
|
| 59 |
+
dim=vocos_dim,
|
| 60 |
+
groups=vocos_dim,
|
| 61 |
+
downsample_scale=ratio,
|
| 62 |
+
),
|
| 63 |
+
VocosBackbone(
|
| 64 |
+
input_channels=vocos_dim,
|
| 65 |
+
dim=vocos_dim,
|
| 66 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 67 |
+
num_layers=2,
|
| 68 |
+
condition_dim=None,
|
| 69 |
+
),
|
| 70 |
+
)
|
| 71 |
+
for ratio in sample_ratios
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
self.downsample = nn.Sequential(*modules)
|
| 75 |
+
|
| 76 |
+
self.project = nn.Linear(vocos_dim, out_channels)
|
| 77 |
+
|
| 78 |
+
def forward(self, x: torch.Tensor, *args):
|
| 79 |
+
"""
|
| 80 |
+
Args:
|
| 81 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
| 85 |
+
"""
|
| 86 |
+
x = self.encoder(x)
|
| 87 |
+
x = self.downsample(x)
|
| 88 |
+
x = self.project(x)
|
| 89 |
+
return x.transpose(1, 2)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# test
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
| 95 |
+
encoder = Encoder(
|
| 96 |
+
input_channels=1024,
|
| 97 |
+
vocos_dim=384,
|
| 98 |
+
vocos_intermediate_dim=2048,
|
| 99 |
+
vocos_num_layers=12,
|
| 100 |
+
out_channels=256,
|
| 101 |
+
sample_ratios=[2, 2],
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
output = encoder(test_input)
|
| 105 |
+
print(output.shape) # torch.Size([8, 256, 12])
|
| 106 |
+
if output.shape == torch.Size([8, 256, 12]):
|
| 107 |
+
print("test successful")
|
models/bicodec_tokenizer/modules/encoder_decoder/wave_generator.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from ..blocks.layers import (
|
| 21 |
+
Snake1d,
|
| 22 |
+
WNConv1d,
|
| 23 |
+
ResidualUnit,
|
| 24 |
+
WNConvTranspose1d,
|
| 25 |
+
init_weights,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DecoderBlock(nn.Module):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
input_dim: int = 16,
|
| 33 |
+
output_dim: int = 8,
|
| 34 |
+
kernel_size: int = 2,
|
| 35 |
+
stride: int = 1,
|
| 36 |
+
):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.block = nn.Sequential(
|
| 39 |
+
Snake1d(input_dim),
|
| 40 |
+
WNConvTranspose1d(
|
| 41 |
+
input_dim,
|
| 42 |
+
output_dim,
|
| 43 |
+
kernel_size=kernel_size,
|
| 44 |
+
stride=stride,
|
| 45 |
+
padding=(kernel_size - stride) // 2,
|
| 46 |
+
),
|
| 47 |
+
ResidualUnit(output_dim, dilation=1),
|
| 48 |
+
ResidualUnit(output_dim, dilation=3),
|
| 49 |
+
ResidualUnit(output_dim, dilation=9),
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
return self.block(x)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class WaveGenerator(nn.Module):
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
input_channel,
|
| 60 |
+
channels,
|
| 61 |
+
rates,
|
| 62 |
+
kernel_sizes,
|
| 63 |
+
d_out: int = 1,
|
| 64 |
+
):
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
# Add first conv layer
|
| 68 |
+
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
| 69 |
+
|
| 70 |
+
# Add upsampling + MRF blocks
|
| 71 |
+
for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
|
| 72 |
+
input_dim = channels // 2**i
|
| 73 |
+
output_dim = channels // 2 ** (i + 1)
|
| 74 |
+
layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
|
| 75 |
+
|
| 76 |
+
# Add final conv layer
|
| 77 |
+
layers += [
|
| 78 |
+
Snake1d(output_dim),
|
| 79 |
+
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
| 80 |
+
nn.Tanh(),
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
self.model = nn.Sequential(*layers)
|
| 84 |
+
|
| 85 |
+
self.apply(init_weights)
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
return self.model(x)
|
models/bicodec_tokenizer/modules/fsq/finite_scalar_quantization.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
| 3 |
+
Code adapted from Jax version in Appendix A.1
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
from functools import wraps, partial
|
| 8 |
+
from contextlib import nullcontext
|
| 9 |
+
from typing import List, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import Module
|
| 14 |
+
from torch import Tensor, int32
|
| 15 |
+
from torch.amp import autocast
|
| 16 |
+
|
| 17 |
+
from einops import rearrange, pack, unpack
|
| 18 |
+
|
| 19 |
+
# helper functions
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def exists(v):
|
| 23 |
+
return v is not None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def default(*args):
|
| 27 |
+
for arg in args:
|
| 28 |
+
if exists(arg):
|
| 29 |
+
return arg
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def maybe(fn):
|
| 34 |
+
@wraps(fn)
|
| 35 |
+
def inner(x, *args, **kwargs):
|
| 36 |
+
if not exists(x):
|
| 37 |
+
return x
|
| 38 |
+
return fn(x, *args, **kwargs)
|
| 39 |
+
|
| 40 |
+
return inner
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def pack_one(t, pattern):
|
| 44 |
+
return pack([t], pattern)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def unpack_one(t, ps, pattern):
|
| 48 |
+
return unpack(t, ps, pattern)[0]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# tensor helpers
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def round_ste(z: Tensor) -> Tensor:
|
| 55 |
+
"""Round with straight through gradients."""
|
| 56 |
+
zhat = z.round()
|
| 57 |
+
return z + (zhat - z).detach()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# main class
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class FSQ(Module):
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
levels: List[int],
|
| 67 |
+
dim: int | None = None,
|
| 68 |
+
num_codebooks=1,
|
| 69 |
+
keep_num_codebooks_dim: bool | None = None,
|
| 70 |
+
scale: float | None = None,
|
| 71 |
+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
|
| 72 |
+
channel_first: bool = False,
|
| 73 |
+
projection_has_bias: bool = True,
|
| 74 |
+
return_indices=True,
|
| 75 |
+
force_quantization_f32=True,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
_levels = torch.tensor(levels, dtype=int32)
|
| 79 |
+
self.register_buffer("_levels", _levels, persistent=False)
|
| 80 |
+
|
| 81 |
+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
| 82 |
+
self.register_buffer("_basis", _basis, persistent=False)
|
| 83 |
+
|
| 84 |
+
self.scale = scale
|
| 85 |
+
|
| 86 |
+
codebook_dim = len(levels)
|
| 87 |
+
self.codebook_dim = codebook_dim
|
| 88 |
+
|
| 89 |
+
effective_codebook_dim = codebook_dim * num_codebooks
|
| 90 |
+
self.num_codebooks = num_codebooks
|
| 91 |
+
self.effective_codebook_dim = effective_codebook_dim
|
| 92 |
+
|
| 93 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
| 94 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
| 95 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
| 96 |
+
|
| 97 |
+
self.dim = default(dim, len(_levels) * num_codebooks)
|
| 98 |
+
|
| 99 |
+
self.channel_first = channel_first
|
| 100 |
+
|
| 101 |
+
has_projections = self.dim != effective_codebook_dim
|
| 102 |
+
self.project_in = (
|
| 103 |
+
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
|
| 104 |
+
if has_projections
|
| 105 |
+
else nn.Identity()
|
| 106 |
+
)
|
| 107 |
+
self.project_out = (
|
| 108 |
+
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
|
| 109 |
+
if has_projections
|
| 110 |
+
else nn.Identity()
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.has_projections = has_projections
|
| 114 |
+
|
| 115 |
+
self.return_indices = return_indices
|
| 116 |
+
if return_indices:
|
| 117 |
+
self.codebook_size = self._levels.prod().item()
|
| 118 |
+
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
|
| 119 |
+
self.register_buffer(
|
| 120 |
+
"implicit_codebook", implicit_codebook, persistent=False
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
self.allowed_dtypes = allowed_dtypes
|
| 124 |
+
self.force_quantization_f32 = force_quantization_f32
|
| 125 |
+
|
| 126 |
+
def bound(self, z, eps: float = 1e-3):
|
| 127 |
+
"""Bound `z`, an array of shape (..., d)."""
|
| 128 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
| 129 |
+
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
| 130 |
+
shift = (offset / half_l).atanh()
|
| 131 |
+
return (z + shift).tanh() * half_l - offset
|
| 132 |
+
|
| 133 |
+
def quantize(self, z):
|
| 134 |
+
"""Quantizes z, returns quantized zhat, same shape as z."""
|
| 135 |
+
quantized = round_ste(self.bound(z))
|
| 136 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
| 137 |
+
return quantized / half_width
|
| 138 |
+
|
| 139 |
+
def _scale_and_shift(self, zhat_normalized):
|
| 140 |
+
half_width = self._levels // 2
|
| 141 |
+
return (zhat_normalized * half_width) + half_width
|
| 142 |
+
|
| 143 |
+
def _scale_and_shift_inverse(self, zhat):
|
| 144 |
+
half_width = self._levels // 2
|
| 145 |
+
return (zhat - half_width) / half_width
|
| 146 |
+
|
| 147 |
+
def _indices_to_codes(self, indices):
|
| 148 |
+
level_indices = self.indices_to_level_indices(indices)
|
| 149 |
+
codes = self._scale_and_shift_inverse(level_indices)
|
| 150 |
+
return codes
|
| 151 |
+
|
| 152 |
+
def codes_to_indices(self, zhat):
|
| 153 |
+
"""Converts a `code` to an index in the codebook."""
|
| 154 |
+
assert zhat.shape[-1] == self.codebook_dim
|
| 155 |
+
zhat = self._scale_and_shift(zhat)
|
| 156 |
+
return (zhat * self._basis).sum(dim=-1).to(int32)
|
| 157 |
+
|
| 158 |
+
def indices_to_level_indices(self, indices):
|
| 159 |
+
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
|
| 160 |
+
indices = rearrange(indices, "... -> ... 1")
|
| 161 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
| 162 |
+
return codes_non_centered
|
| 163 |
+
|
| 164 |
+
def indices_to_codes(self, indices):
|
| 165 |
+
"""Inverse of `codes_to_indices`."""
|
| 166 |
+
assert exists(indices)
|
| 167 |
+
|
| 168 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
| 169 |
+
|
| 170 |
+
codes = self._indices_to_codes(indices)
|
| 171 |
+
|
| 172 |
+
if self.keep_num_codebooks_dim:
|
| 173 |
+
codes = rearrange(codes, "... c d -> ... (c d)")
|
| 174 |
+
|
| 175 |
+
codes = self.project_out(codes)
|
| 176 |
+
|
| 177 |
+
if is_img_or_video or self.channel_first:
|
| 178 |
+
codes = rearrange(codes, "b ... d -> b d ...")
|
| 179 |
+
|
| 180 |
+
return codes
|
| 181 |
+
|
| 182 |
+
def forward(self, z):
|
| 183 |
+
"""
|
| 184 |
+
einstein notation
|
| 185 |
+
b - batch
|
| 186 |
+
n - sequence (or flattened spatial dimensions)
|
| 187 |
+
d - feature dimension
|
| 188 |
+
c - number of codebook dim
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
is_img_or_video = z.ndim >= 4
|
| 192 |
+
need_move_channel_last = is_img_or_video or self.channel_first
|
| 193 |
+
|
| 194 |
+
# standardize image or video into (batch, seq, dimension)
|
| 195 |
+
|
| 196 |
+
if need_move_channel_last:
|
| 197 |
+
z = rearrange(z, "b d ... -> b ... d")
|
| 198 |
+
z, ps = pack_one(z, "b * d")
|
| 199 |
+
|
| 200 |
+
assert (
|
| 201 |
+
z.shape[-1] == self.dim
|
| 202 |
+
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
| 203 |
+
|
| 204 |
+
z = self.project_in(z)
|
| 205 |
+
|
| 206 |
+
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
| 207 |
+
|
| 208 |
+
# whether to force quantization step to be full precision or not
|
| 209 |
+
|
| 210 |
+
force_f32 = self.force_quantization_f32
|
| 211 |
+
quantization_context = (
|
| 212 |
+
partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
with quantization_context():
|
| 216 |
+
orig_dtype = z.dtype
|
| 217 |
+
|
| 218 |
+
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
| 219 |
+
z = z.float()
|
| 220 |
+
|
| 221 |
+
codes = self.quantize(z)
|
| 222 |
+
|
| 223 |
+
# returning indices could be optional
|
| 224 |
+
|
| 225 |
+
indices = None
|
| 226 |
+
|
| 227 |
+
if self.return_indices:
|
| 228 |
+
indices = self.codes_to_indices(codes)
|
| 229 |
+
|
| 230 |
+
codes = rearrange(codes, "b n c d -> b n (c d)")
|
| 231 |
+
|
| 232 |
+
codes = codes.type(orig_dtype)
|
| 233 |
+
|
| 234 |
+
# project out
|
| 235 |
+
|
| 236 |
+
out = self.project_out(codes)
|
| 237 |
+
|
| 238 |
+
# reconstitute image or video dimensions
|
| 239 |
+
|
| 240 |
+
if need_move_channel_last:
|
| 241 |
+
out = unpack_one(out, ps, "b * d")
|
| 242 |
+
out = rearrange(out, "b ... d -> b d ...")
|
| 243 |
+
|
| 244 |
+
indices = maybe(unpack_one)(indices, ps, "b * c")
|
| 245 |
+
|
| 246 |
+
if not self.keep_num_codebooks_dim and self.return_indices:
|
| 247 |
+
indices = maybe(rearrange)(indices, "... 1 -> ...")
|
| 248 |
+
|
| 249 |
+
# return quantized output and indices
|
| 250 |
+
|
| 251 |
+
return out, indices
|
models/bicodec_tokenizer/modules/fsq/residual_fsq.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
from typing import List
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import Module
|
| 9 |
+
from torch.amp import autocast
|
| 10 |
+
from einx import get_at
|
| 11 |
+
from einops import rearrange, reduce, pack, unpack
|
| 12 |
+
|
| 13 |
+
from .finite_scalar_quantization import FSQ
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def exists(val):
|
| 17 |
+
return val is not None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def first(l):
|
| 21 |
+
return l[0]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def default(val, d):
|
| 25 |
+
return val if exists(val) else d
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def round_up_multiple(num, mult):
|
| 29 |
+
return ceil(num / mult) * mult
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# distributed helpers
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def is_distributed():
|
| 36 |
+
return dist.is_initialized() and dist.get_world_size() > 1
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_maybe_sync_seed(device, max_size=10_000):
|
| 40 |
+
rand_int = torch.randint(0, max_size, (), device=device)
|
| 41 |
+
|
| 42 |
+
if is_distributed():
|
| 43 |
+
dist.all_reduce(rand_int)
|
| 44 |
+
|
| 45 |
+
return rand_int.item()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ResidualFSQ(Module):
|
| 49 |
+
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
*,
|
| 54 |
+
levels: List[int],
|
| 55 |
+
num_quantizers,
|
| 56 |
+
dim=None,
|
| 57 |
+
is_channel_first=False,
|
| 58 |
+
quantize_dropout=False,
|
| 59 |
+
quantize_dropout_cutoff_index=0,
|
| 60 |
+
quantize_dropout_multiple_of=1,
|
| 61 |
+
**kwargs,
|
| 62 |
+
):
|
| 63 |
+
super().__init__()
|
| 64 |
+
codebook_dim = len(levels)
|
| 65 |
+
dim = default(dim, codebook_dim)
|
| 66 |
+
|
| 67 |
+
requires_projection = codebook_dim != dim
|
| 68 |
+
self.project_in = (
|
| 69 |
+
nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
|
| 70 |
+
)
|
| 71 |
+
self.project_out = (
|
| 72 |
+
nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
|
| 73 |
+
)
|
| 74 |
+
self.has_projections = requires_projection
|
| 75 |
+
|
| 76 |
+
self.is_channel_first = is_channel_first
|
| 77 |
+
self.num_quantizers = num_quantizers
|
| 78 |
+
|
| 79 |
+
self.levels = levels
|
| 80 |
+
self.layers = nn.ModuleList([])
|
| 81 |
+
|
| 82 |
+
levels_tensor = torch.Tensor(levels)
|
| 83 |
+
|
| 84 |
+
scales = []
|
| 85 |
+
|
| 86 |
+
for ind in range(num_quantizers):
|
| 87 |
+
scales.append((levels_tensor - 1) ** -ind)
|
| 88 |
+
|
| 89 |
+
fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
|
| 90 |
+
|
| 91 |
+
self.layers.append(fsq)
|
| 92 |
+
|
| 93 |
+
assert all([not fsq.has_projections for fsq in self.layers])
|
| 94 |
+
|
| 95 |
+
self.codebook_size = self.layers[0].codebook_size
|
| 96 |
+
|
| 97 |
+
self.register_buffer("scales", torch.stack(scales), persistent=False)
|
| 98 |
+
|
| 99 |
+
self.quantize_dropout = quantize_dropout and num_quantizers > 1
|
| 100 |
+
|
| 101 |
+
assert quantize_dropout_cutoff_index >= 0
|
| 102 |
+
|
| 103 |
+
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
| 104 |
+
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def codebooks(self):
|
| 108 |
+
codebooks = [layer.implicit_codebook for layer in self.layers]
|
| 109 |
+
codebooks = torch.stack(codebooks, dim=0)
|
| 110 |
+
return codebooks
|
| 111 |
+
|
| 112 |
+
def get_codes_from_indices(self, indices):
|
| 113 |
+
|
| 114 |
+
batch, quantize_dim = indices.shape[0], indices.shape[-1]
|
| 115 |
+
|
| 116 |
+
# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
|
| 117 |
+
|
| 118 |
+
indices, ps = pack([indices], "b * q")
|
| 119 |
+
|
| 120 |
+
# because of quantize dropout, one can pass in indices that are coarse
|
| 121 |
+
# and the network should be able to reconstruct
|
| 122 |
+
|
| 123 |
+
if quantize_dim < self.num_quantizers:
|
| 124 |
+
assert (
|
| 125 |
+
self.quantize_dropout > 0.0
|
| 126 |
+
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
| 127 |
+
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
|
| 128 |
+
|
| 129 |
+
# take care of quantizer dropout
|
| 130 |
+
|
| 131 |
+
mask = indices == -1
|
| 132 |
+
indices = indices.masked_fill(
|
| 133 |
+
mask, 0
|
| 134 |
+
) # have it fetch a dummy code to be masked out later
|
| 135 |
+
|
| 136 |
+
all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
|
| 137 |
+
|
| 138 |
+
# mask out any codes that were dropout-ed
|
| 139 |
+
|
| 140 |
+
all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
|
| 141 |
+
|
| 142 |
+
# scale the codes
|
| 143 |
+
|
| 144 |
+
scales = rearrange(self.scales, "q d -> q 1 1 d")
|
| 145 |
+
all_codes = all_codes * scales
|
| 146 |
+
|
| 147 |
+
# if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
|
| 148 |
+
|
| 149 |
+
(all_codes,) = unpack(all_codes, ps, "q b * d")
|
| 150 |
+
|
| 151 |
+
return all_codes
|
| 152 |
+
|
| 153 |
+
def get_output_from_indices(self, indices):
|
| 154 |
+
codes = self.get_codes_from_indices(indices)
|
| 155 |
+
codes_summed = reduce(codes, "q ... -> ...", "sum")
|
| 156 |
+
return self.project_out(codes_summed)
|
| 157 |
+
|
| 158 |
+
def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
|
| 159 |
+
num_quant, quant_dropout_multiple_of, device = (
|
| 160 |
+
self.num_quantizers,
|
| 161 |
+
self.quantize_dropout_multiple_of,
|
| 162 |
+
x.device,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# handle channel first
|
| 166 |
+
|
| 167 |
+
if self.is_channel_first:
|
| 168 |
+
x = rearrange(x, "b d ... -> b ... d")
|
| 169 |
+
x, ps = pack([x], "b * d")
|
| 170 |
+
|
| 171 |
+
# maybe project in
|
| 172 |
+
|
| 173 |
+
x = self.project_in(x)
|
| 174 |
+
|
| 175 |
+
quantized_out = 0.0
|
| 176 |
+
residual = x
|
| 177 |
+
|
| 178 |
+
all_indices = []
|
| 179 |
+
|
| 180 |
+
should_quantize_dropout = self.training and self.quantize_dropout
|
| 181 |
+
|
| 182 |
+
# sample a layer index at which to dropout further residual quantization
|
| 183 |
+
# also prepare null indices
|
| 184 |
+
|
| 185 |
+
if should_quantize_dropout:
|
| 186 |
+
|
| 187 |
+
# check if seed is manually passed in
|
| 188 |
+
|
| 189 |
+
if not exists(rand_quantize_dropout_fixed_seed):
|
| 190 |
+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
|
| 191 |
+
|
| 192 |
+
rand = random.Random(rand_quantize_dropout_fixed_seed)
|
| 193 |
+
|
| 194 |
+
rand_quantize_dropout_index = rand.randrange(
|
| 195 |
+
self.quantize_dropout_cutoff_index, num_quant
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
if quant_dropout_multiple_of != 1:
|
| 199 |
+
rand_quantize_dropout_index = (
|
| 200 |
+
round_up_multiple(
|
| 201 |
+
rand_quantize_dropout_index + 1, quant_dropout_multiple_of
|
| 202 |
+
)
|
| 203 |
+
- 1
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
null_indices = torch.full(
|
| 207 |
+
x.shape[:2], -1.0, device=device, dtype=torch.long
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# go through the layers
|
| 211 |
+
|
| 212 |
+
with autocast("cuda", enabled=False):
|
| 213 |
+
for quantizer_index, (layer, scale) in enumerate(
|
| 214 |
+
zip(self.layers, self.scales)
|
| 215 |
+
):
|
| 216 |
+
|
| 217 |
+
if (
|
| 218 |
+
should_quantize_dropout
|
| 219 |
+
and quantizer_index > rand_quantize_dropout_index
|
| 220 |
+
):
|
| 221 |
+
all_indices.append(null_indices)
|
| 222 |
+
continue
|
| 223 |
+
|
| 224 |
+
quantized, indices = layer(residual / scale)
|
| 225 |
+
|
| 226 |
+
quantized = quantized * scale
|
| 227 |
+
|
| 228 |
+
residual = residual - quantized.detach()
|
| 229 |
+
quantized_out = quantized_out + quantized
|
| 230 |
+
|
| 231 |
+
all_indices.append(indices)
|
| 232 |
+
|
| 233 |
+
# project out, if needed
|
| 234 |
+
|
| 235 |
+
quantized_out = self.project_out(quantized_out)
|
| 236 |
+
|
| 237 |
+
# stack all indices
|
| 238 |
+
|
| 239 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
| 240 |
+
|
| 241 |
+
# channel first out
|
| 242 |
+
|
| 243 |
+
if self.is_channel_first:
|
| 244 |
+
(quantized_out,) = unpack(quantized_out, ps, "b * d")
|
| 245 |
+
(all_indices,) = unpack(all_indices, ps, "b * d")
|
| 246 |
+
|
| 247 |
+
quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
|
| 248 |
+
all_indices = rearrange(all_indices, "b ... d -> b d ...")
|
| 249 |
+
|
| 250 |
+
# return
|
| 251 |
+
|
| 252 |
+
ret = (quantized_out, all_indices)
|
| 253 |
+
|
| 254 |
+
if not return_all_codes:
|
| 255 |
+
return ret
|
| 256 |
+
|
| 257 |
+
# whether to return all codes from all codebooks across layers
|
| 258 |
+
|
| 259 |
+
all_codes = self.get_codes_from_indices(all_indices)
|
| 260 |
+
|
| 261 |
+
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
| 262 |
+
|
| 263 |
+
return (*ret, all_codes)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# grouped residual fsq
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class GroupedResidualFSQ(Module):
|
| 270 |
+
def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.dim = dim
|
| 273 |
+
self.groups = groups
|
| 274 |
+
assert (dim % groups) == 0
|
| 275 |
+
dim_per_group = dim // groups
|
| 276 |
+
|
| 277 |
+
self.accept_image_fmap = accept_image_fmap
|
| 278 |
+
|
| 279 |
+
self.rvqs = nn.ModuleList([])
|
| 280 |
+
|
| 281 |
+
for _ in range(groups):
|
| 282 |
+
self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
|
| 283 |
+
|
| 284 |
+
self.codebook_size = self.rvqs[0].codebook_size
|
| 285 |
+
|
| 286 |
+
@property
|
| 287 |
+
def codebooks(self):
|
| 288 |
+
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
| 289 |
+
|
| 290 |
+
@property
|
| 291 |
+
def split_dim(self):
|
| 292 |
+
return 1 if self.accept_image_fmap else -1
|
| 293 |
+
|
| 294 |
+
def get_codes_from_indices(self, indices):
|
| 295 |
+
codes = tuple(
|
| 296 |
+
rvq.get_codes_from_indices(chunk_indices)
|
| 297 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
| 298 |
+
)
|
| 299 |
+
return torch.stack(codes)
|
| 300 |
+
|
| 301 |
+
def get_output_from_indices(self, indices):
|
| 302 |
+
outputs = tuple(
|
| 303 |
+
rvq.get_output_from_indices(chunk_indices)
|
| 304 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
| 305 |
+
)
|
| 306 |
+
return torch.cat(outputs, dim=self.split_dim)
|
| 307 |
+
|
| 308 |
+
def forward(self, x, return_all_codes=False):
|
| 309 |
+
shape, split_dim, device = x.shape, self.split_dim, x.device
|
| 310 |
+
assert shape[split_dim] == self.dim
|
| 311 |
+
|
| 312 |
+
# split the feature dimension into groups
|
| 313 |
+
|
| 314 |
+
x = x.chunk(self.groups, dim=split_dim)
|
| 315 |
+
|
| 316 |
+
forward_kwargs = dict(
|
| 317 |
+
return_all_codes=return_all_codes,
|
| 318 |
+
rand_quantize_dropout_fixed_seed=(
|
| 319 |
+
get_maybe_sync_seed(device) if self.training else None
|
| 320 |
+
),
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# invoke residual vq on each group
|
| 324 |
+
|
| 325 |
+
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
| 326 |
+
out = tuple(zip(*out))
|
| 327 |
+
|
| 328 |
+
# otherwise, get all the zipped outputs and combine them
|
| 329 |
+
|
| 330 |
+
quantized, all_indices, *maybe_all_codes = out
|
| 331 |
+
|
| 332 |
+
quantized = torch.cat(quantized, dim=split_dim)
|
| 333 |
+
all_indices = torch.stack(all_indices)
|
| 334 |
+
|
| 335 |
+
ret = (quantized, all_indices, *maybe_all_codes)
|
| 336 |
+
return ret
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
if __name__ == "__main__":
|
| 340 |
+
model = ResidualFSQ(
|
| 341 |
+
levels=[4, 4, 4, 4, 4, 4],
|
| 342 |
+
num_quantizers=1,
|
| 343 |
+
dim=30,
|
| 344 |
+
is_channel_first=True,
|
| 345 |
+
quantize_dropout=False,
|
| 346 |
+
)
|
| 347 |
+
x = torch.randn(2, 30, 10)
|
| 348 |
+
quantize, embed_ind = model(x)
|
| 349 |
+
|
| 350 |
+
emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
|
| 351 |
+
|
| 352 |
+
print(quantize == emb_from_ind.transpose(1, 2))
|
| 353 |
+
|
| 354 |
+
print("quantize shape", quantize.shape)
|
| 355 |
+
print("embed_ind", embed_ind)
|
models/bicodec_tokenizer/modules/speaker/__init__.py
ADDED
|
File without changes
|
models/bicodec_tokenizer/modules/speaker/ecapa_tdnn.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Zhengyang Chen (chenzhengyang117@gmail.com)
|
| 2 |
+
# 2022 Hongji Wang (jijijiang77@gmail.com)
|
| 3 |
+
# 2023 Bing Han (hanbing97@sjtu.edu.cn)
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
""" This implementation is adapted from github repo:
|
| 18 |
+
https://github.com/lawlict/ECAPA-TDNN.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
from . import pooling_layers
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Res2Conv1dReluBn(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
in_channels == out_channels == channels
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
channels,
|
| 36 |
+
kernel_size=1,
|
| 37 |
+
stride=1,
|
| 38 |
+
padding=0,
|
| 39 |
+
dilation=1,
|
| 40 |
+
bias=True,
|
| 41 |
+
scale=4,
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
| 45 |
+
self.scale = scale
|
| 46 |
+
self.width = channels // scale
|
| 47 |
+
self.nums = scale if scale == 1 else scale - 1
|
| 48 |
+
|
| 49 |
+
self.convs = []
|
| 50 |
+
self.bns = []
|
| 51 |
+
for i in range(self.nums):
|
| 52 |
+
self.convs.append(
|
| 53 |
+
nn.Conv1d(
|
| 54 |
+
self.width,
|
| 55 |
+
self.width,
|
| 56 |
+
kernel_size,
|
| 57 |
+
stride,
|
| 58 |
+
padding,
|
| 59 |
+
dilation,
|
| 60 |
+
bias=bias,
|
| 61 |
+
)
|
| 62 |
+
)
|
| 63 |
+
self.bns.append(nn.BatchNorm1d(self.width))
|
| 64 |
+
self.convs = nn.ModuleList(self.convs)
|
| 65 |
+
self.bns = nn.ModuleList(self.bns)
|
| 66 |
+
|
| 67 |
+
def forward(self, x):
|
| 68 |
+
out = []
|
| 69 |
+
spx = torch.split(x, self.width, 1)
|
| 70 |
+
sp = spx[0]
|
| 71 |
+
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
|
| 72 |
+
# Order: conv -> relu -> bn
|
| 73 |
+
if i >= 1:
|
| 74 |
+
sp = sp + spx[i]
|
| 75 |
+
sp = conv(sp)
|
| 76 |
+
sp = bn(F.relu(sp))
|
| 77 |
+
out.append(sp)
|
| 78 |
+
if self.scale != 1:
|
| 79 |
+
out.append(spx[self.nums])
|
| 80 |
+
out = torch.cat(out, dim=1)
|
| 81 |
+
|
| 82 |
+
return out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
""" Conv1d + BatchNorm1d + ReLU
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Conv1dReluBn(nn.Module):
|
| 90 |
+
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
in_channels,
|
| 94 |
+
out_channels,
|
| 95 |
+
kernel_size=1,
|
| 96 |
+
stride=1,
|
| 97 |
+
padding=0,
|
| 98 |
+
dilation=1,
|
| 99 |
+
bias=True,
|
| 100 |
+
):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.conv = nn.Conv1d(
|
| 103 |
+
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
|
| 104 |
+
)
|
| 105 |
+
self.bn = nn.BatchNorm1d(out_channels)
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
return self.bn(F.relu(self.conv(x)))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
""" The SE connection of 1D case.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class SE_Connect(nn.Module):
|
| 116 |
+
|
| 117 |
+
def __init__(self, channels, se_bottleneck_dim=128):
|
| 118 |
+
super().__init__()
|
| 119 |
+
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
| 120 |
+
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
| 121 |
+
|
| 122 |
+
def forward(self, x):
|
| 123 |
+
out = x.mean(dim=2)
|
| 124 |
+
out = F.relu(self.linear1(out))
|
| 125 |
+
out = torch.sigmoid(self.linear2(out))
|
| 126 |
+
out = x * out.unsqueeze(2)
|
| 127 |
+
|
| 128 |
+
return out
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class SE_Res2Block(nn.Module):
|
| 136 |
+
|
| 137 |
+
def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.se_res2block = nn.Sequential(
|
| 140 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
| 141 |
+
Res2Conv1dReluBn(
|
| 142 |
+
channels, kernel_size, stride, padding, dilation, scale=scale
|
| 143 |
+
),
|
| 144 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
| 145 |
+
SE_Connect(channels),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
return x + self.se_res2block(x)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class ECAPA_TDNN(nn.Module):
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
channels=512,
|
| 157 |
+
feat_dim=80,
|
| 158 |
+
embed_dim=192,
|
| 159 |
+
pooling_func="ASTP",
|
| 160 |
+
global_context_att=False,
|
| 161 |
+
emb_bn=False,
|
| 162 |
+
):
|
| 163 |
+
super().__init__()
|
| 164 |
+
|
| 165 |
+
self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
|
| 166 |
+
self.layer2 = SE_Res2Block(
|
| 167 |
+
channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
|
| 168 |
+
)
|
| 169 |
+
self.layer3 = SE_Res2Block(
|
| 170 |
+
channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
|
| 171 |
+
)
|
| 172 |
+
self.layer4 = SE_Res2Block(
|
| 173 |
+
channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
cat_channels = channels * 3
|
| 177 |
+
out_channels = 512 * 3
|
| 178 |
+
self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
|
| 179 |
+
self.pool = getattr(pooling_layers, pooling_func)(
|
| 180 |
+
in_dim=out_channels, global_context_att=global_context_att
|
| 181 |
+
)
|
| 182 |
+
self.pool_out_dim = self.pool.get_out_dim()
|
| 183 |
+
self.bn = nn.BatchNorm1d(self.pool_out_dim)
|
| 184 |
+
self.linear = nn.Linear(self.pool_out_dim, embed_dim)
|
| 185 |
+
self.emb_bn = emb_bn
|
| 186 |
+
if emb_bn: # better in SSL for SV
|
| 187 |
+
self.bn2 = nn.BatchNorm1d(embed_dim)
|
| 188 |
+
else:
|
| 189 |
+
self.bn2 = nn.Identity()
|
| 190 |
+
|
| 191 |
+
def forward(self, x, return_latent=False):
|
| 192 |
+
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
|
| 193 |
+
|
| 194 |
+
out1 = self.layer1(x)
|
| 195 |
+
out2 = self.layer2(out1)
|
| 196 |
+
out3 = self.layer3(out2)
|
| 197 |
+
out4 = self.layer4(out3)
|
| 198 |
+
|
| 199 |
+
out = torch.cat([out2, out3, out4], dim=1)
|
| 200 |
+
latent = F.relu(self.conv(out))
|
| 201 |
+
out = self.bn(self.pool(latent))
|
| 202 |
+
out = self.linear(out)
|
| 203 |
+
if self.emb_bn:
|
| 204 |
+
out = self.bn2(out)
|
| 205 |
+
|
| 206 |
+
if return_latent:
|
| 207 |
+
return out, latent
|
| 208 |
+
return out
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 212 |
+
return ECAPA_TDNN(
|
| 213 |
+
channels=1024,
|
| 214 |
+
feat_dim=feat_dim,
|
| 215 |
+
embed_dim=embed_dim,
|
| 216 |
+
pooling_func=pooling_func,
|
| 217 |
+
emb_bn=emb_bn,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 222 |
+
return ECAPA_TDNN(
|
| 223 |
+
channels=1024,
|
| 224 |
+
feat_dim=feat_dim,
|
| 225 |
+
embed_dim=embed_dim,
|
| 226 |
+
pooling_func=pooling_func,
|
| 227 |
+
global_context_att=True,
|
| 228 |
+
emb_bn=emb_bn,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 233 |
+
return ECAPA_TDNN(
|
| 234 |
+
channels=512,
|
| 235 |
+
feat_dim=feat_dim,
|
| 236 |
+
embed_dim=embed_dim,
|
| 237 |
+
pooling_func=pooling_func,
|
| 238 |
+
emb_bn=emb_bn,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 243 |
+
return ECAPA_TDNN(
|
| 244 |
+
channels=512,
|
| 245 |
+
feat_dim=feat_dim,
|
| 246 |
+
embed_dim=embed_dim,
|
| 247 |
+
pooling_func=pooling_func,
|
| 248 |
+
global_context_att=True,
|
| 249 |
+
emb_bn=emb_bn,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
x = torch.zeros(1, 200, 100)
|
| 255 |
+
model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
|
| 256 |
+
model.eval()
|
| 257 |
+
out, latent = model(x, True)
|
| 258 |
+
print(out.shape)
|
| 259 |
+
print(latent.shape)
|
| 260 |
+
|
| 261 |
+
num_params = sum(param.numel() for param in model.parameters())
|
| 262 |
+
print("{} M".format(num_params / 1e6))
|
| 263 |
+
|
| 264 |
+
# from thop import profile
|
| 265 |
+
# x_np = torch.randn(1, 200, 80)
|
| 266 |
+
# flops, params = profile(model, inputs=(x_np, ))
|
| 267 |
+
# print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
|
models/bicodec_tokenizer/modules/speaker/perceiver_encoder.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
| 17 |
+
|
| 18 |
+
from collections import namedtuple
|
| 19 |
+
from functools import wraps
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from einops import rearrange, repeat
|
| 24 |
+
from einops.layers.torch import Rearrange
|
| 25 |
+
from packaging import version
|
| 26 |
+
from torch import einsum, nn
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def exists(val):
|
| 30 |
+
return val is not None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def once(fn):
|
| 34 |
+
called = False
|
| 35 |
+
|
| 36 |
+
@wraps(fn)
|
| 37 |
+
def inner(x):
|
| 38 |
+
nonlocal called
|
| 39 |
+
if called:
|
| 40 |
+
return
|
| 41 |
+
called = True
|
| 42 |
+
return fn(x)
|
| 43 |
+
|
| 44 |
+
return inner
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
print_once = once(print)
|
| 48 |
+
|
| 49 |
+
# main class
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Attend(nn.Module):
|
| 53 |
+
def __init__(self, dropout=0.0, causal=False, use_flash=False):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.dropout = dropout
|
| 56 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 57 |
+
|
| 58 |
+
self.causal = causal
|
| 59 |
+
self.register_buffer("mask", None, persistent=False)
|
| 60 |
+
|
| 61 |
+
self.use_flash = use_flash
|
| 62 |
+
assert not (
|
| 63 |
+
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
| 64 |
+
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
| 65 |
+
|
| 66 |
+
# determine efficient attention configs for cuda and cpu
|
| 67 |
+
self.config = namedtuple(
|
| 68 |
+
"EfficientAttentionConfig",
|
| 69 |
+
["enable_flash", "enable_math", "enable_mem_efficient"],
|
| 70 |
+
)
|
| 71 |
+
self.cpu_config = self.config(True, True, True)
|
| 72 |
+
self.cuda_config = None
|
| 73 |
+
|
| 74 |
+
if not torch.cuda.is_available() or not use_flash:
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
| 78 |
+
|
| 79 |
+
if device_properties.major == 8 and device_properties.minor == 0:
|
| 80 |
+
print_once(
|
| 81 |
+
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
| 82 |
+
)
|
| 83 |
+
self.cuda_config = self.config(True, False, False)
|
| 84 |
+
else:
|
| 85 |
+
print_once(
|
| 86 |
+
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
| 87 |
+
)
|
| 88 |
+
self.cuda_config = self.config(False, True, True)
|
| 89 |
+
|
| 90 |
+
def get_mask(self, n, device):
|
| 91 |
+
if exists(self.mask) and self.mask.shape[-1] >= n:
|
| 92 |
+
return self.mask[:n, :n]
|
| 93 |
+
|
| 94 |
+
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
| 95 |
+
self.register_buffer("mask", mask, persistent=False)
|
| 96 |
+
return mask
|
| 97 |
+
|
| 98 |
+
def flash_attn(self, q, k, v, mask=None):
|
| 99 |
+
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
|
| 100 |
+
|
| 101 |
+
# Recommended for multi-query single-key-value attention by Tri Dao
|
| 102 |
+
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
| 103 |
+
|
| 104 |
+
if k.ndim == 3:
|
| 105 |
+
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
|
| 106 |
+
|
| 107 |
+
if v.ndim == 3:
|
| 108 |
+
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
|
| 109 |
+
|
| 110 |
+
# Check if mask exists and expand to compatible shape
|
| 111 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
| 112 |
+
|
| 113 |
+
if exists(mask):
|
| 114 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
| 115 |
+
mask = mask.expand(-1, heads, q_len, -1)
|
| 116 |
+
|
| 117 |
+
# Check if there is a compatible device for flash attention
|
| 118 |
+
|
| 119 |
+
config = self.cuda_config if is_cuda else self.cpu_config
|
| 120 |
+
|
| 121 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
| 122 |
+
|
| 123 |
+
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
| 124 |
+
out = F.scaled_dot_product_attention(
|
| 125 |
+
q,
|
| 126 |
+
k,
|
| 127 |
+
v,
|
| 128 |
+
attn_mask=mask,
|
| 129 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 130 |
+
is_causal=self.causal,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
return out
|
| 134 |
+
|
| 135 |
+
def forward(self, q, k, v, mask=None):
|
| 136 |
+
"""
|
| 137 |
+
einstein notation
|
| 138 |
+
b - batch
|
| 139 |
+
h - heads
|
| 140 |
+
n, i, j - sequence length (base sequence length, source, target)
|
| 141 |
+
d - feature dimension
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
n, device = q.shape[-2], q.device
|
| 145 |
+
|
| 146 |
+
scale = q.shape[-1] ** -0.5
|
| 147 |
+
|
| 148 |
+
if self.use_flash:
|
| 149 |
+
return self.flash_attn(q, k, v, mask=mask)
|
| 150 |
+
|
| 151 |
+
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
|
| 152 |
+
|
| 153 |
+
# similarity
|
| 154 |
+
|
| 155 |
+
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
| 156 |
+
|
| 157 |
+
# key padding mask
|
| 158 |
+
|
| 159 |
+
if exists(mask):
|
| 160 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
| 161 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
| 162 |
+
|
| 163 |
+
# causal mask
|
| 164 |
+
|
| 165 |
+
if self.causal:
|
| 166 |
+
causal_mask = self.get_mask(n, device)
|
| 167 |
+
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
| 168 |
+
|
| 169 |
+
# attention
|
| 170 |
+
|
| 171 |
+
attn = sim.softmax(dim=-1)
|
| 172 |
+
attn = self.attn_dropout(attn)
|
| 173 |
+
|
| 174 |
+
# aggregate values
|
| 175 |
+
|
| 176 |
+
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
| 177 |
+
|
| 178 |
+
return out
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def Sequential(*mods):
|
| 182 |
+
return nn.Sequential(*filter(exists, mods))
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def exists(x):
|
| 186 |
+
return x is not None
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def default(val, d):
|
| 190 |
+
if exists(val):
|
| 191 |
+
return val
|
| 192 |
+
return d() if callable(d) else d
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class RMSNorm(nn.Module):
|
| 196 |
+
def __init__(self, dim, scale=True, dim_cond=None):
|
| 197 |
+
super().__init__()
|
| 198 |
+
self.cond = exists(dim_cond)
|
| 199 |
+
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
|
| 200 |
+
|
| 201 |
+
self.scale = dim**0.5
|
| 202 |
+
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
|
| 203 |
+
|
| 204 |
+
def forward(self, x, cond=None):
|
| 205 |
+
gamma = default(self.gamma, 1)
|
| 206 |
+
out = F.normalize(x, dim=-1) * self.scale * gamma
|
| 207 |
+
|
| 208 |
+
if not self.cond:
|
| 209 |
+
return out
|
| 210 |
+
|
| 211 |
+
assert exists(cond)
|
| 212 |
+
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
|
| 213 |
+
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
|
| 214 |
+
return out * gamma + beta
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class CausalConv1d(nn.Conv1d):
|
| 218 |
+
def __init__(self, *args, **kwargs):
|
| 219 |
+
super().__init__(*args, **kwargs)
|
| 220 |
+
(kernel_size,) = self.kernel_size
|
| 221 |
+
(dilation,) = self.dilation
|
| 222 |
+
(stride,) = self.stride
|
| 223 |
+
|
| 224 |
+
assert stride == 1
|
| 225 |
+
self.causal_padding = dilation * (kernel_size - 1)
|
| 226 |
+
|
| 227 |
+
def forward(self, x):
|
| 228 |
+
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
| 229 |
+
return super().forward(causal_padded_x)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class GEGLU(nn.Module):
|
| 233 |
+
def forward(self, x):
|
| 234 |
+
x, gate = x.chunk(2, dim=-1)
|
| 235 |
+
return F.gelu(gate) * x
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def FeedForward(dim, mult=4, causal_conv=False):
|
| 239 |
+
dim_inner = int(dim * mult * 2 / 3)
|
| 240 |
+
|
| 241 |
+
conv = None
|
| 242 |
+
if causal_conv:
|
| 243 |
+
conv = nn.Sequential(
|
| 244 |
+
Rearrange("b n d -> b d n"),
|
| 245 |
+
CausalConv1d(dim_inner, dim_inner, 3),
|
| 246 |
+
Rearrange("b d n -> b n d"),
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
return Sequential(
|
| 250 |
+
nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class Attention(nn.Module):
|
| 255 |
+
def __init__(
|
| 256 |
+
self,
|
| 257 |
+
dim,
|
| 258 |
+
*,
|
| 259 |
+
dim_context=None,
|
| 260 |
+
causal=False,
|
| 261 |
+
dim_head=64,
|
| 262 |
+
heads=8,
|
| 263 |
+
dropout=0.0,
|
| 264 |
+
use_flash=False,
|
| 265 |
+
cross_attn_include_queries=False,
|
| 266 |
+
):
|
| 267 |
+
super().__init__()
|
| 268 |
+
self.scale = dim_head**-0.5
|
| 269 |
+
self.heads = heads
|
| 270 |
+
self.cross_attn_include_queries = cross_attn_include_queries
|
| 271 |
+
|
| 272 |
+
dim_inner = dim_head * heads
|
| 273 |
+
dim_context = default(dim_context, dim)
|
| 274 |
+
|
| 275 |
+
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
|
| 276 |
+
self.to_q = nn.Linear(dim, dim_inner, bias=False)
|
| 277 |
+
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
|
| 278 |
+
self.to_out = nn.Linear(dim_inner, dim, bias=False)
|
| 279 |
+
|
| 280 |
+
def forward(self, x, context=None, mask=None):
|
| 281 |
+
h, has_context = self.heads, exists(context)
|
| 282 |
+
|
| 283 |
+
context = default(context, x)
|
| 284 |
+
|
| 285 |
+
if has_context and self.cross_attn_include_queries:
|
| 286 |
+
context = torch.cat((x, context), dim=-2)
|
| 287 |
+
|
| 288 |
+
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
| 289 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
| 290 |
+
|
| 291 |
+
out = self.attend(q, k, v, mask=mask)
|
| 292 |
+
|
| 293 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 294 |
+
return self.to_out(out)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class PerceiverResampler(nn.Module):
|
| 298 |
+
def __init__(
|
| 299 |
+
self,
|
| 300 |
+
*,
|
| 301 |
+
dim,
|
| 302 |
+
depth=2,
|
| 303 |
+
dim_context=None,
|
| 304 |
+
num_latents=32,
|
| 305 |
+
dim_head=64,
|
| 306 |
+
heads=8,
|
| 307 |
+
ff_mult=4,
|
| 308 |
+
use_flash_attn=False,
|
| 309 |
+
):
|
| 310 |
+
super().__init__()
|
| 311 |
+
dim_context = default(dim_context, dim)
|
| 312 |
+
|
| 313 |
+
self.proj_context = (
|
| 314 |
+
nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
| 318 |
+
nn.init.normal_(self.latents, std=0.02)
|
| 319 |
+
|
| 320 |
+
self.layers = nn.ModuleList([])
|
| 321 |
+
for _ in range(depth):
|
| 322 |
+
self.layers.append(
|
| 323 |
+
nn.ModuleList(
|
| 324 |
+
[
|
| 325 |
+
Attention(
|
| 326 |
+
dim=dim,
|
| 327 |
+
dim_head=dim_head,
|
| 328 |
+
heads=heads,
|
| 329 |
+
use_flash=use_flash_attn,
|
| 330 |
+
cross_attn_include_queries=True,
|
| 331 |
+
),
|
| 332 |
+
FeedForward(dim=dim, mult=ff_mult),
|
| 333 |
+
]
|
| 334 |
+
)
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
self.norm = RMSNorm(dim)
|
| 338 |
+
|
| 339 |
+
def forward(self, x, mask=None):
|
| 340 |
+
batch = x.shape[0]
|
| 341 |
+
|
| 342 |
+
x = self.proj_context(x)
|
| 343 |
+
|
| 344 |
+
latents = repeat(self.latents, "n d -> b n d", b=batch)
|
| 345 |
+
|
| 346 |
+
for attn, ff in self.layers:
|
| 347 |
+
latents = attn(latents, x, mask=mask) + latents
|
| 348 |
+
latents = ff(latents) + latents
|
| 349 |
+
|
| 350 |
+
return self.norm(latents)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
|
| 354 |
+
model = PerceiverResampler(dim=256, dim_context=80)
|
| 355 |
+
x = torch.randn(8, 200, 80)
|
| 356 |
+
out = model(x)
|
| 357 |
+
print(out.shape) # [8, 32, 80]
|
| 358 |
+
|
| 359 |
+
num_params = sum(param.numel() for param in model.parameters())
|
| 360 |
+
print("{} M".format(num_params / 1e6))
|
models/bicodec_tokenizer/modules/speaker/pooling_layers.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Pooling functions to aggregate frame-level deep features
|
| 16 |
+
into segment-level speaker embeddings
|
| 17 |
+
|
| 18 |
+
High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
|
| 19 |
+
even though we remove the mean statistic, on Voxceleb.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TAP(nn.Module):
|
| 28 |
+
"""
|
| 29 |
+
Temporal average pooling, only first-order mean is considered
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, in_dim=0, **kwargs):
|
| 33 |
+
super(TAP, self).__init__()
|
| 34 |
+
self.in_dim = in_dim
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
pooling_mean = x.mean(dim=-1)
|
| 38 |
+
# To be compatable with 2D input
|
| 39 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
| 40 |
+
return pooling_mean
|
| 41 |
+
|
| 42 |
+
def get_out_dim(self):
|
| 43 |
+
self.out_dim = self.in_dim
|
| 44 |
+
return self.out_dim
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TSDP(nn.Module):
|
| 48 |
+
"""
|
| 49 |
+
Temporal standard deviation pooling, only second-order std is considered
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, in_dim=0, **kwargs):
|
| 53 |
+
super(TSDP, self).__init__()
|
| 54 |
+
self.in_dim = in_dim
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
# The last dimension is the temporal axis
|
| 58 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
| 59 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
| 60 |
+
return pooling_std
|
| 61 |
+
|
| 62 |
+
def get_out_dim(self):
|
| 63 |
+
self.out_dim = self.in_dim
|
| 64 |
+
return self.out_dim
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TSTP(nn.Module):
|
| 68 |
+
"""
|
| 69 |
+
Temporal statistics pooling, concatenate mean and std, which is used in
|
| 70 |
+
x-vector
|
| 71 |
+
Comment: simple concatenation can not make full use of both statistics
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(self, in_dim=0, **kwargs):
|
| 75 |
+
super(TSTP, self).__init__()
|
| 76 |
+
self.in_dim = in_dim
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
# The last dimension is the temporal axis
|
| 80 |
+
pooling_mean = x.mean(dim=-1)
|
| 81 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
| 82 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
| 83 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
| 84 |
+
stats = torch.cat((pooling_mean, pooling_std), 1)
|
| 85 |
+
return stats
|
| 86 |
+
|
| 87 |
+
def get_out_dim(self):
|
| 88 |
+
self.out_dim = self.in_dim * 2
|
| 89 |
+
return self.out_dim
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ASTP(nn.Module):
|
| 93 |
+
""" Attentive statistics pooling: Channel- and context-dependent
|
| 94 |
+
statistics pooling, first used in ECAPA_TDNN.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(self,
|
| 98 |
+
in_dim,
|
| 99 |
+
bottleneck_dim=128,
|
| 100 |
+
global_context_att=False,
|
| 101 |
+
**kwargs):
|
| 102 |
+
super(ASTP, self).__init__()
|
| 103 |
+
self.in_dim = in_dim
|
| 104 |
+
self.global_context_att = global_context_att
|
| 105 |
+
|
| 106 |
+
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
| 107 |
+
# need to transpose inputs.
|
| 108 |
+
if global_context_att:
|
| 109 |
+
self.linear1 = nn.Conv1d(
|
| 110 |
+
in_dim * 3, bottleneck_dim,
|
| 111 |
+
kernel_size=1) # equals W and b in the paper
|
| 112 |
+
else:
|
| 113 |
+
self.linear1 = nn.Conv1d(
|
| 114 |
+
in_dim, bottleneck_dim,
|
| 115 |
+
kernel_size=1) # equals W and b in the paper
|
| 116 |
+
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
| 117 |
+
kernel_size=1) # equals V and k in the paper
|
| 118 |
+
|
| 119 |
+
def forward(self, x):
|
| 120 |
+
"""
|
| 121 |
+
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
|
| 122 |
+
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
|
| 123 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
| 124 |
+
"""
|
| 125 |
+
if len(x.shape) == 4:
|
| 126 |
+
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
|
| 127 |
+
assert len(x.shape) == 3
|
| 128 |
+
|
| 129 |
+
if self.global_context_att:
|
| 130 |
+
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
| 131 |
+
context_std = torch.sqrt(
|
| 132 |
+
torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
|
| 133 |
+
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
| 134 |
+
else:
|
| 135 |
+
x_in = x
|
| 136 |
+
|
| 137 |
+
# DON'T use ReLU here! ReLU may be hard to converge.
|
| 138 |
+
alpha = torch.tanh(
|
| 139 |
+
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
| 140 |
+
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
| 141 |
+
mean = torch.sum(alpha * x, dim=2)
|
| 142 |
+
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
| 143 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
| 144 |
+
return torch.cat([mean, std], dim=1)
|
| 145 |
+
|
| 146 |
+
def get_out_dim(self):
|
| 147 |
+
self.out_dim = 2 * self.in_dim
|
| 148 |
+
return self.out_dim
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MHASTP(torch.nn.Module):
|
| 152 |
+
""" Multi head attentive statistics pooling
|
| 153 |
+
Reference:
|
| 154 |
+
Self Multi-Head Attention for Speaker Recognition
|
| 155 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
def __init__(self,
|
| 159 |
+
in_dim,
|
| 160 |
+
layer_num=2,
|
| 161 |
+
head_num=2,
|
| 162 |
+
d_s=1,
|
| 163 |
+
bottleneck_dim=64,
|
| 164 |
+
**kwargs):
|
| 165 |
+
super(MHASTP, self).__init__()
|
| 166 |
+
assert (in_dim % head_num
|
| 167 |
+
) == 0 # make sure that head num can be divided by input_dim
|
| 168 |
+
self.in_dim = in_dim
|
| 169 |
+
self.head_num = head_num
|
| 170 |
+
d_model = int(in_dim / head_num)
|
| 171 |
+
channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
|
| 172 |
+
if d_s > 1:
|
| 173 |
+
d_s = d_model
|
| 174 |
+
else:
|
| 175 |
+
d_s = 1
|
| 176 |
+
self.d_s = d_s
|
| 177 |
+
channel_dims[0], channel_dims[-1] = d_model, d_s
|
| 178 |
+
heads_att_trans = []
|
| 179 |
+
for i in range(self.head_num):
|
| 180 |
+
att_trans = nn.Sequential()
|
| 181 |
+
for i in range(layer_num - 1):
|
| 182 |
+
att_trans.add_module(
|
| 183 |
+
'att_' + str(i),
|
| 184 |
+
nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
|
| 185 |
+
att_trans.add_module('tanh' + str(i), nn.Tanh())
|
| 186 |
+
att_trans.add_module(
|
| 187 |
+
'att_' + str(layer_num - 1),
|
| 188 |
+
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
|
| 189 |
+
1, 1))
|
| 190 |
+
heads_att_trans.append(att_trans)
|
| 191 |
+
self.heads_att_trans = nn.ModuleList(heads_att_trans)
|
| 192 |
+
|
| 193 |
+
def forward(self, input):
|
| 194 |
+
"""
|
| 195 |
+
input: a 3-dimensional tensor in xvector architecture
|
| 196 |
+
or a 4-dimensional tensor in resnet architecture
|
| 197 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
| 198 |
+
"""
|
| 199 |
+
if len(input.shape) == 4: # B x F x T
|
| 200 |
+
input = input.reshape(input.shape[0],
|
| 201 |
+
input.shape[1] * input.shape[2],
|
| 202 |
+
input.shape[3])
|
| 203 |
+
assert len(input.shape) == 3
|
| 204 |
+
bs, f_dim, t_dim = input.shape
|
| 205 |
+
chunks = torch.chunk(input, self.head_num, 1)
|
| 206 |
+
# split
|
| 207 |
+
chunks_out = []
|
| 208 |
+
# for i in range(self.head_num):
|
| 209 |
+
# att_score = self.heads_att_trans[i](chunks[i])
|
| 210 |
+
for i, layer in enumerate(self.heads_att_trans):
|
| 211 |
+
att_score = layer(chunks[i])
|
| 212 |
+
alpha = F.softmax(att_score, dim=-1)
|
| 213 |
+
mean = torch.sum(alpha * chunks[i], dim=2)
|
| 214 |
+
var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
|
| 215 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
| 216 |
+
chunks_out.append(torch.cat((mean, std), dim=1))
|
| 217 |
+
out = torch.cat(chunks_out, dim=1)
|
| 218 |
+
return out
|
| 219 |
+
|
| 220 |
+
def get_out_dim(self):
|
| 221 |
+
self.out_dim = 2 * self.in_dim
|
| 222 |
+
return self.out_dim
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class MQMHASTP(torch.nn.Module):
|
| 226 |
+
""" An attentive pooling
|
| 227 |
+
Reference:
|
| 228 |
+
multi query multi head attentive statistics pooling
|
| 229 |
+
https://arxiv.org/pdf/2110.05042.pdf
|
| 230 |
+
Args:
|
| 231 |
+
in_dim: the feature dimension of input
|
| 232 |
+
layer_num: the number of layer in the pooling layer
|
| 233 |
+
query_num: the number of querys
|
| 234 |
+
head_num: the number of heads
|
| 235 |
+
bottleneck_dim: the bottleneck dimension
|
| 236 |
+
|
| 237 |
+
SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
|
| 238 |
+
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
|
| 239 |
+
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
|
| 240 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
| 241 |
+
AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
|
| 242 |
+
https://arxiv.org/pdf/1803.10963.pdf
|
| 243 |
+
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
|
| 244 |
+
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(self,
|
| 248 |
+
in_dim,
|
| 249 |
+
layer_num=2,
|
| 250 |
+
query_num=2,
|
| 251 |
+
head_num=8,
|
| 252 |
+
d_s=2,
|
| 253 |
+
bottleneck_dim=64,
|
| 254 |
+
**kwargs):
|
| 255 |
+
super(MQMHASTP, self).__init__()
|
| 256 |
+
self.n_query = nn.ModuleList([
|
| 257 |
+
MHASTP(in_dim,
|
| 258 |
+
layer_num=layer_num,
|
| 259 |
+
head_num=head_num,
|
| 260 |
+
d_s=d_s,
|
| 261 |
+
bottleneck_dim=bottleneck_dim) for i in range(query_num)
|
| 262 |
+
])
|
| 263 |
+
self.query_num = query_num
|
| 264 |
+
self.in_dim = in_dim
|
| 265 |
+
|
| 266 |
+
def forward(self, input):
|
| 267 |
+
"""
|
| 268 |
+
input: a 3-dimensional tensor in xvector architecture
|
| 269 |
+
or a 4-dimensional tensor in resnet architecture
|
| 270 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
| 271 |
+
"""
|
| 272 |
+
if len(input.shape) == 4: # B x F x T
|
| 273 |
+
input = input.reshape(input.shape[0],
|
| 274 |
+
input.shape[1] * input.shape[2],
|
| 275 |
+
input.shape[3])
|
| 276 |
+
assert len(input.shape) == 3
|
| 277 |
+
res = []
|
| 278 |
+
for i, layer in enumerate(self.n_query):
|
| 279 |
+
res.append(layer(input))
|
| 280 |
+
out = torch.cat(res, dim=-1)
|
| 281 |
+
return out
|
| 282 |
+
|
| 283 |
+
def get_out_dim(self):
|
| 284 |
+
self.out_dim = self.in_dim * 2 * self.query_num
|
| 285 |
+
return self.out_dim
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
if __name__ == '__main__':
|
| 289 |
+
data = torch.randn(16, 512, 10, 35)
|
| 290 |
+
# model = StatisticsPooling()
|
| 291 |
+
model = MQMHASTP(512 * 10)
|
| 292 |
+
model = MHASTP(512 * 10)
|
| 293 |
+
model = MQMHASTP(512 * 10, context=False)
|
| 294 |
+
print(model)
|
| 295 |
+
|
| 296 |
+
out = model(data)
|
| 297 |
+
print(out.shape)
|
| 298 |
+
print(model.get_out_dim())
|
models/bicodec_tokenizer/modules/speaker/speaker_encoder.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
|
| 19 |
+
from typing import List, Tuple
|
| 20 |
+
from ..fsq.residual_fsq import ResidualFSQ
|
| 21 |
+
from .ecapa_tdnn import ECAPA_TDNN_GLOB_c512
|
| 22 |
+
from .perceiver_encoder import PerceiverResampler
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
x-vector + d-vector
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SpeakerEncoder(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
input_dim (int): acoustic feature dimension
|
| 34 |
+
out_dim (int): output dimension of x-vector and d-vector
|
| 35 |
+
latent_dim (int): latent dimension before quantization
|
| 36 |
+
token_num (int): sequence length of speaker tokens
|
| 37 |
+
fsq_levels (List[int]): number of levels for each quantizer
|
| 38 |
+
fsq_num_quantizers (int): number of quantizers
|
| 39 |
+
|
| 40 |
+
Return:
|
| 41 |
+
speaker_embs: (B, T2, out_dim)
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
input_dim: int = 100,
|
| 47 |
+
out_dim: int = 512,
|
| 48 |
+
latent_dim: int = 128,
|
| 49 |
+
token_num: int = 32,
|
| 50 |
+
fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
|
| 51 |
+
fsq_num_quantizers: int = 1,
|
| 52 |
+
):
|
| 53 |
+
super(SpeakerEncoder, self).__init__()
|
| 54 |
+
|
| 55 |
+
self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
|
| 56 |
+
feat_dim=input_dim, embed_dim=out_dim
|
| 57 |
+
)
|
| 58 |
+
self.perceiver_sampler = PerceiverResampler(
|
| 59 |
+
dim=latent_dim, dim_context=512 * 3, num_latents=token_num
|
| 60 |
+
)
|
| 61 |
+
self.quantizer = ResidualFSQ(
|
| 62 |
+
levels=fsq_levels,
|
| 63 |
+
num_quantizers=fsq_num_quantizers,
|
| 64 |
+
dim=latent_dim,
|
| 65 |
+
is_channel_first=True,
|
| 66 |
+
quantize_dropout=False,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
self.project = nn.Linear(latent_dim * token_num, out_dim)
|
| 70 |
+
|
| 71 |
+
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
|
| 72 |
+
zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
|
| 73 |
+
return zq.transpose(1, 2)
|
| 74 |
+
|
| 75 |
+
def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
|
| 76 |
+
mels = mels.transpose(1, 2)
|
| 77 |
+
x = self.perceiver_sampler(mels).transpose(1, 2)
|
| 78 |
+
zq, indices = self.quantizer(x)
|
| 79 |
+
return indices
|
| 80 |
+
|
| 81 |
+
def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 82 |
+
"""
|
| 83 |
+
Args:
|
| 84 |
+
mels: (B, D_mel, T1)
|
| 85 |
+
|
| 86 |
+
Return:
|
| 87 |
+
x_vector: (B, out_dim)
|
| 88 |
+
d_vector: (B, out_dim)
|
| 89 |
+
"""
|
| 90 |
+
# mels = mels.transpose(1,2)
|
| 91 |
+
|
| 92 |
+
x_vector, features = self.speaker_encoder(mels, True)
|
| 93 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
| 94 |
+
zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim)
|
| 95 |
+
x = zq.reshape(zq.shape[0], -1)
|
| 96 |
+
d_vector = self.project(x)
|
| 97 |
+
|
| 98 |
+
return x_vector, d_vector
|
| 99 |
+
|
| 100 |
+
def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
"""tokenize the input mel spectrogram"""
|
| 102 |
+
_, features = self.speaker_encoder(mels, True)
|
| 103 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
| 104 |
+
zq, indices = self.quantizer(x)
|
| 105 |
+
return indices
|
| 106 |
+
|
| 107 |
+
def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
|
| 108 |
+
"""detokenize the input indices to d-vector"""
|
| 109 |
+
zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
|
| 110 |
+
x = zq.reshape(zq.shape[0], -1)
|
| 111 |
+
d_vector = self.project(x)
|
| 112 |
+
return d_vector
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
model = SpeakerEncoder(
|
| 116 |
+
input_dim=100,
|
| 117 |
+
latent_dim=128,
|
| 118 |
+
token_num=32,
|
| 119 |
+
fsq_levels=[4, 4, 4, 4, 4, 4],
|
| 120 |
+
fsq_num_quantizers=1,
|
| 121 |
+
)
|
| 122 |
+
mel = torch.randn(8, 200, 100)
|
| 123 |
+
x_vector, d_vector = model(mel)
|
| 124 |
+
print("x-vector shape", x_vector.shape)
|
| 125 |
+
print("d-vector shape", d_vector.shape)
|
| 126 |
+
|
| 127 |
+
indices = model.tokenize(mel)
|
| 128 |
+
print("indices shape", indices.shape)
|
| 129 |
+
d_vector_post = model.detokenize(indices)
|
| 130 |
+
print("d-vector shape", d_vector_post.shape)
|
| 131 |
+
if d_vector_post.all() == d_vector.all():
|
| 132 |
+
print("d-vector post and d-vector are the same")
|
| 133 |
+
else:
|
| 134 |
+
print("d-vector post and d-vector are different")
|
| 135 |
+
num_params = sum(param.numel() for param in model.parameters())
|
| 136 |
+
print("{} M".format(num_params / 1e6))
|
models/bicodec_tokenizer/modules/vq/factorized_vector_quantize.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from typing import Any, Dict
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from einops import rearrange
|
| 25 |
+
from torch.nn.utils import weight_norm
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def WNConv1d(*args, **kwargs):
|
| 29 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def ema_inplace(moving_avg, new, decay):
|
| 33 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class FactorizedVectorQuantize(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
input_dim: int,
|
| 40 |
+
codebook_size: int,
|
| 41 |
+
codebook_dim: int,
|
| 42 |
+
commitment: float,
|
| 43 |
+
codebook_loss_weight: float = 1.0,
|
| 44 |
+
decay: float = 0.99,
|
| 45 |
+
threshold_ema_dead_code: float = 2,
|
| 46 |
+
momentum: float = 0.99,
|
| 47 |
+
**kwargs,
|
| 48 |
+
):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.input_dim = input_dim
|
| 51 |
+
self.codebook_size = codebook_size
|
| 52 |
+
self.codebook_dim = codebook_dim
|
| 53 |
+
self.commitment = commitment
|
| 54 |
+
self.codebook_loss_weight = codebook_loss_weight
|
| 55 |
+
self.decay = decay
|
| 56 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 57 |
+
self.momentum = momentum
|
| 58 |
+
|
| 59 |
+
if input_dim != self.codebook_dim:
|
| 60 |
+
self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
|
| 61 |
+
self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
|
| 62 |
+
|
| 63 |
+
else:
|
| 64 |
+
self.in_project = nn.Identity()
|
| 65 |
+
self.out_project = nn.Identity()
|
| 66 |
+
|
| 67 |
+
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
| 68 |
+
self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
|
| 69 |
+
|
| 70 |
+
def forward(self, z: torch.Tensor) -> Dict[str, Any]:
|
| 71 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
| 72 |
+
the corresponding codebook vectors
|
| 73 |
+
|
| 74 |
+
Parameters
|
| 75 |
+
----------
|
| 76 |
+
z : Tensor[B x D x T]
|
| 77 |
+
|
| 78 |
+
Returns
|
| 79 |
+
-------
|
| 80 |
+
Tensor[B x D x T]
|
| 81 |
+
Quantized continuous representation of input
|
| 82 |
+
Tensor[1]
|
| 83 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 84 |
+
entries
|
| 85 |
+
Tensor[1]
|
| 86 |
+
Codebook loss to update the codebook
|
| 87 |
+
Tensor[B x T]
|
| 88 |
+
Codebook indices (quantized discrete representation of input)
|
| 89 |
+
Tensor[B x D x T]
|
| 90 |
+
Projected latents (continuous representation of input before quantization)
|
| 91 |
+
"""
|
| 92 |
+
# transpose since we use linear
|
| 93 |
+
|
| 94 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
| 95 |
+
z_e = self.in_project(z)
|
| 96 |
+
z_q, indices, dists = self.decode_latents(z_e)
|
| 97 |
+
|
| 98 |
+
# statistic the usage of codes
|
| 99 |
+
embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
|
| 100 |
+
avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
|
| 101 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
| 102 |
+
|
| 103 |
+
active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
|
| 104 |
+
if self.training:
|
| 105 |
+
# We do the expiry of code at that point as buffers are in sync
|
| 106 |
+
# and all the workers will take the same decision.
|
| 107 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
|
| 108 |
+
active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
|
| 109 |
+
|
| 110 |
+
if self.training:
|
| 111 |
+
commit_loss = (
|
| 112 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
| 113 |
+
* self.commitment
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
codebook_loss = (
|
| 117 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
| 118 |
+
* self.codebook_loss_weight
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
else:
|
| 122 |
+
commit_loss = torch.zeros(0, device=z.device)
|
| 123 |
+
codebook_loss = torch.zeros(0, device=z.device)
|
| 124 |
+
|
| 125 |
+
z_q = (
|
| 126 |
+
z_e + (z_q - z_e).detach()
|
| 127 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
| 128 |
+
|
| 129 |
+
z_q = self.out_project(z_q)
|
| 130 |
+
|
| 131 |
+
vq_loss = (commit_loss + codebook_loss).mean()
|
| 132 |
+
|
| 133 |
+
return {
|
| 134 |
+
"z_q": z_q,
|
| 135 |
+
"indices": indices,
|
| 136 |
+
"dists": dists,
|
| 137 |
+
"vq_loss": vq_loss,
|
| 138 |
+
"perplexity": perplexity,
|
| 139 |
+
"active_num": active_num.float(),
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
def vq2emb(self, vq, out_proj=True):
|
| 143 |
+
emb = self.embed_code(vq)
|
| 144 |
+
if out_proj:
|
| 145 |
+
emb = self.out_project(emb)
|
| 146 |
+
return emb
|
| 147 |
+
|
| 148 |
+
def tokenize(self, z: torch.Tensor) -> torch.Tensor:
|
| 149 |
+
"""tokenize the input tensor"""
|
| 150 |
+
z_e = self.in_project(z)
|
| 151 |
+
_, indices, _ = self.decode_latents(z_e)
|
| 152 |
+
return indices
|
| 153 |
+
|
| 154 |
+
def detokenize(self, indices):
|
| 155 |
+
"""detokenize the input indices"""
|
| 156 |
+
z_q = self.decode_code(indices)
|
| 157 |
+
z_q = self.out_project(z_q)
|
| 158 |
+
return z_q
|
| 159 |
+
|
| 160 |
+
def get_emb(self):
|
| 161 |
+
return self.codebook.weight
|
| 162 |
+
|
| 163 |
+
def embed_code(self, embed_id):
|
| 164 |
+
return F.embedding(embed_id, self.codebook.weight)
|
| 165 |
+
|
| 166 |
+
def decode_code(self, embed_id):
|
| 167 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
| 168 |
+
|
| 169 |
+
def decode_latents(self, latents):
|
| 170 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
| 171 |
+
codebook = self.codebook.weight
|
| 172 |
+
|
| 173 |
+
# L2 normalize encodings and codebook
|
| 174 |
+
encodings = F.normalize(encodings)
|
| 175 |
+
codebook = F.normalize(codebook)
|
| 176 |
+
|
| 177 |
+
# Compute euclidean distance between encodings and codebook,
|
| 178 |
+
# with L2 normalization, the distance is equal to cosine distance
|
| 179 |
+
dist = (
|
| 180 |
+
encodings.pow(2).sum(1, keepdim=True)
|
| 181 |
+
- 2 * encodings @ codebook.t()
|
| 182 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 183 |
+
)
|
| 184 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 185 |
+
z_q = self.decode_code(indices)
|
| 186 |
+
|
| 187 |
+
return z_q, indices, dist
|
models/bicodec_tokenizer/spark_detokenizer.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Time :2025/3/29 10:34
|
| 3 |
+
# Author :Hui Huang
|
| 4 |
+
import os
|
| 5 |
+
from typing import Literal
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from .base_model import SparkBaseModel
|
| 9 |
+
from .batch_processor import AsyncBatchEngine
|
| 10 |
+
from .tokenizer_utils import get_dtype
|
| 11 |
+
from .modules.encoder_decoder.feat_decoder import Decoder
|
| 12 |
+
from .modules.encoder_decoder.wave_generator import WaveGenerator
|
| 13 |
+
from .modules.speaker.speaker_encoder import SpeakerEncoder
|
| 14 |
+
from .modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
|
| 15 |
+
|
| 16 |
+
__all__ = ["SparkDeTokenizer"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SparkDeTokenizerModel(SparkBaseModel):
|
| 20 |
+
def __init__(self, config):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
self.quantizer = FactorizedVectorQuantize(**config["quantizer"])
|
| 24 |
+
self.prenet = Decoder(**config["prenet"])
|
| 25 |
+
self.decoder = WaveGenerator(**config["decoder"])
|
| 26 |
+
self.speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
|
| 27 |
+
|
| 28 |
+
@torch.no_grad()
|
| 29 |
+
def forward(
|
| 30 |
+
self,
|
| 31 |
+
semantic_tokens: torch.Tensor,
|
| 32 |
+
global_tokens: torch.Tensor
|
| 33 |
+
) -> torch.Tensor:
|
| 34 |
+
z_q = self.quantizer.detokenize(semantic_tokens)
|
| 35 |
+
d_vector = self.speaker_encoder.detokenize(global_tokens)
|
| 36 |
+
x = self.prenet(z_q, d_vector)
|
| 37 |
+
x = x + d_vector.unsqueeze(-1)
|
| 38 |
+
wav_recon = self.decoder(x)
|
| 39 |
+
return wav_recon.detach()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SparkDeTokenizer:
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
model_path: str,
|
| 46 |
+
device: Literal["cpu", "cuda", "mps"] | str = "cpu",
|
| 47 |
+
batch_size: int = 32,
|
| 48 |
+
wait_timeout: float = 0.01):
|
| 49 |
+
self.device = torch.device(device)
|
| 50 |
+
self.model = SparkDeTokenizerModel.from_pretrained(model_path).to(self.device)
|
| 51 |
+
self.device_type = device
|
| 52 |
+
self.dtype = get_dtype(self.device_type)
|
| 53 |
+
self._batch_processor = AsyncBatchEngine(
|
| 54 |
+
processing_function=self.batch_detokenize_async,
|
| 55 |
+
batch_size=batch_size,
|
| 56 |
+
wait_timeout=wait_timeout
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
@torch.no_grad()
|
| 60 |
+
def detokenize(
|
| 61 |
+
self,
|
| 62 |
+
semantic_tokens: torch.Tensor,
|
| 63 |
+
global_tokens: torch.Tensor
|
| 64 |
+
) -> torch.Tensor:
|
| 65 |
+
with torch.amp.autocast(self.device_type, dtype=self.dtype):
|
| 66 |
+
output = self.model(
|
| 67 |
+
semantic_tokens.to(self.device),
|
| 68 |
+
global_tokens.to(self.device)
|
| 69 |
+
)
|
| 70 |
+
return output
|
| 71 |
+
|
| 72 |
+
async def batch_detokenize_async(self, requests: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]:
|
| 73 |
+
semantic_tokens, global_tokens = [], []
|
| 74 |
+
lengths = []
|
| 75 |
+
for request in requests:
|
| 76 |
+
semantic_tokens.append(request["semantic_tokens"])
|
| 77 |
+
global_tokens.append(request["global_tokens"])
|
| 78 |
+
lengths.append(len(request['semantic_tokens']))
|
| 79 |
+
# Concatenate tokens for batch processing
|
| 80 |
+
global_tokens = torch.stack(global_tokens, dim=0)
|
| 81 |
+
semantic_tokens = torch.nn.utils.rnn.pad_sequence(
|
| 82 |
+
semantic_tokens, batch_first=True, padding_value=0
|
| 83 |
+
)
|
| 84 |
+
# print(f"tokenizer global_tokens shape is {global_tokens.shape}")
|
| 85 |
+
# print(f"tokenizer semantic_tokens shape is {semantic_tokens.shape}")
|
| 86 |
+
audios = self.detokenize(
|
| 87 |
+
semantic_tokens=semantic_tokens,
|
| 88 |
+
global_tokens=global_tokens
|
| 89 |
+
).detach().cpu()
|
| 90 |
+
# Prepare responses
|
| 91 |
+
responses = []
|
| 92 |
+
for i in range(len(requests)):
|
| 93 |
+
audio = audios[i, :, :(lengths[i] * 320)] # 大概一个token对应audio长度320
|
| 94 |
+
responses.append({
|
| 95 |
+
"audio": audio,
|
| 96 |
+
})
|
| 97 |
+
|
| 98 |
+
if self.device.type == "cuda":
|
| 99 |
+
torch.cuda.empty_cache()
|
| 100 |
+
return responses
|
| 101 |
+
|
| 102 |
+
async def detokenize_async(self, request: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 103 |
+
output = await self._batch_processor.add_request(
|
| 104 |
+
single_input=request
|
| 105 |
+
)
|
| 106 |
+
return output.get("feature")
|
models/bicodec_tokenizer/spark_tokenizer.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Time :2025/3/29 10:30
|
| 3 |
+
# Author :Hui Huang
|
| 4 |
+
import os
|
| 5 |
+
from typing import Literal, Optional, Tuple, Dict, Any, List, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio
|
| 9 |
+
import torchaudio.transforms as TT
|
| 10 |
+
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
|
| 11 |
+
import numpy as np
|
| 12 |
+
from loguru import logger
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# ----------------- 假设这些模块位于你的项目路径下 -----------------
|
| 16 |
+
from .utils.file import load_config
|
| 17 |
+
from .utils.audio import load_audio
|
| 18 |
+
from .models.bicodec import BiCodec
|
| 19 |
+
from .base_model import SparkBaseModel
|
| 20 |
+
from .batch_processor import AsyncBatchEngine
|
| 21 |
+
# ---------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
__all__ = ["SparkTokenizer"]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SparkTokenizer:
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
model_path: str,
|
| 30 |
+
device: Literal["cpu", "cuda", "mps"] | str = "cuda",
|
| 31 |
+
attn_implementation: Optional[Literal["sdpa", "flash_attention_2", "eager"]] = "eager",
|
| 32 |
+
batch_size: int = 32,
|
| 33 |
+
wait_timeout: float = 0.01,
|
| 34 |
+
):
|
| 35 |
+
self.device = torch.device(device)
|
| 36 |
+
self.model_dir = Path(model_path)
|
| 37 |
+
|
| 38 |
+
# 1. 加载配置
|
| 39 |
+
self.config = load_config(self.model_dir / "config.yaml")
|
| 40 |
+
self.device_type = "cuda" if "cuda" in str(device) else "cpu"
|
| 41 |
+
self.dtype = torch.float16 if self.device_type == "cuda" else torch.float32
|
| 42 |
+
self.target_sample_rate = self.config.get("sample_rate", 16000)
|
| 43 |
+
|
| 44 |
+
# 2. 加载模型
|
| 45 |
+
wav2vec_path = self.model_dir / "wav2vec2-large-xlsr-53"
|
| 46 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
|
| 47 |
+
self.feature_extractor = Wav2Vec2Model.from_pretrained(
|
| 48 |
+
wav2vec_path,
|
| 49 |
+
attn_implementation=attn_implementation,
|
| 50 |
+
torch_dtype=self.dtype
|
| 51 |
+
)
|
| 52 |
+
self.feature_extractor.config.output_hidden_states = True
|
| 53 |
+
self.feature_extractor.to(self.device)
|
| 54 |
+
self.feature_extractor.eval()
|
| 55 |
+
|
| 56 |
+
# BiCodec model
|
| 57 |
+
self.model = (
|
| 58 |
+
BiCodec.load_from_checkpoint(str(self.model_dir)).to(self.device).half()
|
| 59 |
+
)
|
| 60 |
+
self.model.eval()
|
| 61 |
+
|
| 62 |
+
# 异步处理引擎
|
| 63 |
+
self._batch_processor = AsyncBatchEngine(
|
| 64 |
+
processing_function=self.batch_tokenize_async,
|
| 65 |
+
batch_size=batch_size,
|
| 66 |
+
wait_timeout=wait_timeout
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def _to_ndarray(self, audio_input: Union[str, Path, torch.Tensor]) -> np.ndarray:
|
| 70 |
+
"""
|
| 71 |
+
将输入(路径或Tensor)统一转换为指定采样率的 numpy 数组。
|
| 72 |
+
"""
|
| 73 |
+
if isinstance(audio_input, (str, Path)):
|
| 74 |
+
# 如果是路径,直接使用原有的 load_audio
|
| 75 |
+
wav = load_audio(
|
| 76 |
+
str(audio_input),
|
| 77 |
+
sampling_rate=self.target_sample_rate,
|
| 78 |
+
volume_normalize=self.config.get("volume_normalize", True),
|
| 79 |
+
)
|
| 80 |
+
elif isinstance(audio_input, torch.Tensor):
|
| 81 |
+
# 如果是 Tensor
|
| 82 |
+
wav = audio_input.detach().cpu().float()
|
| 83 |
+
|
| 84 |
+
# 处理通道: [C, T] -> [T]
|
| 85 |
+
if wav.ndim > 1:
|
| 86 |
+
wav = torch.mean(wav, dim=0)
|
| 87 |
+
|
| 88 |
+
# 这里默认输入的 Tensor 采样率已经是 self.target_sample_rate
|
| 89 |
+
# 如果需要在这里做重采样,需要额外传入输入采样率参数
|
| 90 |
+
wav = wav.numpy()
|
| 91 |
+
|
| 92 |
+
# 可选:音量归一化逻辑(如果 Tensor 没归一化)
|
| 93 |
+
if self.config.get("volume_normalize", True):
|
| 94 |
+
max_val = np.abs(wav).max()
|
| 95 |
+
if max_val > 0:
|
| 96 |
+
wav = wav / max_val * 0.9
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f"Unsupported audio type: {type(audio_input)}")
|
| 99 |
+
|
| 100 |
+
return wav
|
| 101 |
+
|
| 102 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
| 103 |
+
"""获取参考音频片段"""
|
| 104 |
+
ref_segment_length = (
|
| 105 |
+
int(self.target_sample_rate * self.config["ref_segment_duration"])
|
| 106 |
+
// self.config["latent_hop_length"]
|
| 107 |
+
* self.config["latent_hop_length"]
|
| 108 |
+
)
|
| 109 |
+
wav_length = len(wav)
|
| 110 |
+
|
| 111 |
+
if ref_segment_length > wav_length:
|
| 112 |
+
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
| 113 |
+
|
| 114 |
+
return wav[:ref_segment_length]
|
| 115 |
+
|
| 116 |
+
def process_audio(self, audio_input: Union[str, torch.Tensor], ref_audio_input: Union[str, torch.Tensor] = None) -> Tuple[np.ndarray, torch.Tensor]:
|
| 117 |
+
"""
|
| 118 |
+
处理音频和参考音频。
|
| 119 |
+
"""
|
| 120 |
+
wav = self._to_ndarray(audio_input)
|
| 121 |
+
|
| 122 |
+
if ref_audio_input is None:
|
| 123 |
+
wav_ref_np = self.get_ref_clip(wav)
|
| 124 |
+
else:
|
| 125 |
+
ref_wav = self._to_ndarray(ref_audio_input)
|
| 126 |
+
wav_ref_np = self.get_ref_clip(ref_wav)
|
| 127 |
+
|
| 128 |
+
wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
|
| 129 |
+
return wav, wav_ref
|
| 130 |
+
|
| 131 |
+
def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
"""提取 wav2vec2 特征"""
|
| 133 |
+
# processor 期望是 list of numpy
|
| 134 |
+
inputs = self.processor(
|
| 135 |
+
[w.cpu().numpy() for w in wavs],
|
| 136 |
+
sampling_rate=16000,
|
| 137 |
+
return_tensors="pt",
|
| 138 |
+
padding=True,
|
| 139 |
+
).input_values
|
| 140 |
+
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
with torch.amp.autocast(self.device_type, dtype=self.dtype):
|
| 143 |
+
feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
|
| 144 |
+
|
| 145 |
+
feats_mix = (
|
| 146 |
+
feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
|
| 147 |
+
) / 3
|
| 148 |
+
|
| 149 |
+
return feats_mix
|
| 150 |
+
|
| 151 |
+
@torch.no_grad()
|
| 152 |
+
def tokenize(self, audios: List[Union[str, torch.Tensor]]):
|
| 153 |
+
"""
|
| 154 |
+
支持音频路径列表或 Tensor 列表。
|
| 155 |
+
"""
|
| 156 |
+
batch_wavs = []
|
| 157 |
+
batch_ref_wavs = []
|
| 158 |
+
|
| 159 |
+
for audio_item in audios:
|
| 160 |
+
wav, wav_ref = self.process_audio(audio_input=audio_item, ref_audio_input=audio_item)
|
| 161 |
+
batch_wavs.append(torch.from_numpy(wav).float())
|
| 162 |
+
batch_ref_wavs.append(wav_ref.squeeze(0))
|
| 163 |
+
|
| 164 |
+
# Padding wavs
|
| 165 |
+
wav_lengths = [len(w) for w in batch_wavs]
|
| 166 |
+
max_wav_len = max(wav_lengths)
|
| 167 |
+
padded_wavs = torch.zeros(len(batch_wavs), max_wav_len, dtype=self.dtype).to(self.device)
|
| 168 |
+
for i, w in enumerate(batch_wavs):
|
| 169 |
+
padded_wavs[i, :len(w)] = w.to(self.dtype)
|
| 170 |
+
|
| 171 |
+
# Padding ref_wavs
|
| 172 |
+
ref_lengths = [len(w) for w in batch_ref_wavs]
|
| 173 |
+
max_ref_len = max(ref_lengths)
|
| 174 |
+
padded_ref_wavs = torch.zeros(len(batch_ref_wavs), max_ref_len, dtype=self.dtype).to(self.device)
|
| 175 |
+
for i, w in enumerate(batch_ref_wavs):
|
| 176 |
+
padded_ref_wavs[i, :len(w)] = w.to(self.dtype)
|
| 177 |
+
|
| 178 |
+
# 提取特征
|
| 179 |
+
feats = self.extract_wav2vec2_features(padded_wavs)
|
| 180 |
+
|
| 181 |
+
batch = {
|
| 182 |
+
"wav": padded_wavs,
|
| 183 |
+
"ref_wav": padded_ref_wavs,
|
| 184 |
+
"feat": feats,
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
| 188 |
+
|
| 189 |
+
if self.device.type == "cuda":
|
| 190 |
+
torch.cuda.empty_cache()
|
| 191 |
+
|
| 192 |
+
return {"semantic_tokens": semantic_tokens, "global_tokens": global_tokens}
|
| 193 |
+
|
| 194 |
+
async def batch_tokenize_async(self, audios: list) -> list[dict[str, torch.Tensor]]:
|
| 195 |
+
tokenized = self.tokenize(audios)
|
| 196 |
+
responses = []
|
| 197 |
+
for i in range(len(audios)):
|
| 198 |
+
responses.append({
|
| 199 |
+
"global_tokens": tokenized["global_tokens"][i],
|
| 200 |
+
"semantic_tokens": tokenized["semantic_tokens"][i]
|
| 201 |
+
})
|
| 202 |
+
return responses
|
| 203 |
+
|
| 204 |
+
async def tokenize_async(self, audio: Union[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 205 |
+
output = await self._batch_processor.add_request(
|
| 206 |
+
single_input=audio
|
| 207 |
+
)
|
| 208 |
+
return output
|
| 209 |
+
|
| 210 |
+
# ------------------------------------------------------------------
|
| 211 |
+
# 测试用例
|
| 212 |
+
# ------------------------------------------------------------------
|
| 213 |
+
if __name__ == "__main__":
|
| 214 |
+
# 配置你的模型路径
|
| 215 |
+
MODEL_DIR = "/data/yumu/model/ark_tts_v1"
|
| 216 |
+
|
| 217 |
+
# 初始化
|
| 218 |
+
# 注意:在没有真实环境时,这行会因为找不到文件报错,请在有环境的地方运行
|
| 219 |
+
tokenizer = SparkTokenizer(model_path=MODEL_DIR, device="cuda" if torch.cuda.is_available() else "cpu")
|
| 220 |
+
|
| 221 |
+
# 准备数据:一个是本地存在的 wav 路径,一个是构造的 Tensor
|
| 222 |
+
dummy_wav_path = "/data/yumu/arktts/dufu.wav"
|
| 223 |
+
# 构造一个 16kHz 的 2 秒音频 Tensor (假设模型要求16k)
|
| 224 |
+
import torchaudio
|
| 225 |
+
dummy_tensor, sr = torchaudio.load(dummy_wav_path)
|
| 226 |
+
|
| 227 |
+
# 1. 测试路径输入
|
| 228 |
+
print("Testing path input...")
|
| 229 |
+
if os.path.exists(dummy_wav_path):
|
| 230 |
+
res1 = tokenizer.tokenize([dummy_wav_path])
|
| 231 |
+
print(f"Path results: {res1['semantic_tokens'].shape}")
|
| 232 |
+
|
| 233 |
+
# 2. 测试 Tensor 输入
|
| 234 |
+
print("Testing tensor input...")
|
| 235 |
+
res2 = tokenizer.tokenize([dummy_tensor])
|
| 236 |
+
print(f"Tensor results: {res2['semantic_tokens'].shape}")
|
| 237 |
+
|
| 238 |
+
# 3. 测试混合输入 (List 包含 str 和 Tensor)
|
| 239 |
+
print("Testing mixed input...")
|
| 240 |
+
# 为了演示,我们传两个相同的 tensor
|
| 241 |
+
res3 = tokenizer.tokenize([dummy_tensor, dummy_tensor])
|
| 242 |
+
print(f"Mixed results: {res3['semantic_tokens'].shape}")
|
| 243 |
+
|
| 244 |
+
print("All tests passed!")
|
models/bicodec_tokenizer/tokenizer_utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Time :2025/3/29 10:27
|
| 3 |
+
# Author :Hui Huang
|
| 4 |
+
from omegaconf import OmegaConf, DictConfig
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_config(config_path: str) -> DictConfig:
|
| 9 |
+
"""Loads a configuration file and optionally merges it with a base configuration.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
config_path (Path): Path to the configuration file.
|
| 13 |
+
"""
|
| 14 |
+
# Load the initial configuration from the given path
|
| 15 |
+
config = OmegaConf.load(config_path)
|
| 16 |
+
|
| 17 |
+
# Check if there is a base configuration specified and merge if necessary
|
| 18 |
+
if config.get("base_config", None) is not None:
|
| 19 |
+
base_config = OmegaConf.load(config["base_config"])
|
| 20 |
+
config = OmegaConf.merge(base_config, config)
|
| 21 |
+
|
| 22 |
+
return config
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def gpu_supports_fp16() -> bool:
|
| 26 |
+
# 1. 确保 CUDA 可用
|
| 27 |
+
if not torch.cuda.is_available():
|
| 28 |
+
return False
|
| 29 |
+
|
| 30 |
+
# 2. 获取设备的 compute capability
|
| 31 |
+
major, minor = torch.cuda.get_device_capability()
|
| 32 |
+
|
| 33 |
+
# 3. 判断是否 >= 5.3
|
| 34 |
+
if major > 5 or (major == 5 and minor >= 3):
|
| 35 |
+
return True
|
| 36 |
+
else:
|
| 37 |
+
return False
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_dtype(device: str):
|
| 41 |
+
if device.startswith('cuda') and gpu_supports_fp16():
|
| 42 |
+
return torch.float16
|
| 43 |
+
else:
|
| 44 |
+
return torch.float32
|
models/bicodec_tokenizer/utils/__init__.py
ADDED
|
File without changes
|
models/bicodec_tokenizer/utils/audio.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Description:
|
| 17 |
+
This script contains a collection of functions designed to handle various
|
| 18 |
+
audio processing.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import random
|
| 22 |
+
import soxr
|
| 23 |
+
import soundfile
|
| 24 |
+
import torch
|
| 25 |
+
import torchaudio
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Tuple
|
| 30 |
+
from numpy.lib.stride_tricks import sliding_window_view
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
|
| 34 |
+
"""
|
| 35 |
+
Normalize the volume of an audio signal.
|
| 36 |
+
|
| 37 |
+
Parameters:
|
| 38 |
+
audio (numpy array): Input audio signal array.
|
| 39 |
+
coeff (float): Target coefficient for normalization, default is 0.2.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
numpy array: The volume-normalized audio signal.
|
| 43 |
+
"""
|
| 44 |
+
# Sort the absolute values of the audio signal
|
| 45 |
+
temp = np.sort(np.abs(audio))
|
| 46 |
+
|
| 47 |
+
# If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
|
| 48 |
+
if temp[-1] < 0.1:
|
| 49 |
+
scaling_factor = max(
|
| 50 |
+
temp[-1], 1e-3
|
| 51 |
+
) # Prevent division by zero with a small constant
|
| 52 |
+
audio = audio / scaling_factor * 0.1
|
| 53 |
+
|
| 54 |
+
# Filter out values less than 0.01 from temp
|
| 55 |
+
temp = temp[temp > 0.01]
|
| 56 |
+
L = temp.shape[0] # Length of the filtered array
|
| 57 |
+
|
| 58 |
+
# If there are fewer than or equal to 10 significant values, return the audio without further processing
|
| 59 |
+
if L <= 10:
|
| 60 |
+
return audio
|
| 61 |
+
|
| 62 |
+
# Compute the average of the top 10% to 1% of values in temp
|
| 63 |
+
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
|
| 64 |
+
|
| 65 |
+
# Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
|
| 66 |
+
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
|
| 67 |
+
|
| 68 |
+
# Ensure the maximum absolute value in the audio does not exceed 1
|
| 69 |
+
max_value = np.max(np.abs(audio))
|
| 70 |
+
if max_value > 1:
|
| 71 |
+
audio = audio / max_value
|
| 72 |
+
|
| 73 |
+
return audio
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def load_audio(
|
| 77 |
+
adfile: Path,
|
| 78 |
+
sampling_rate: int = None,
|
| 79 |
+
length: int = None,
|
| 80 |
+
volume_normalize: bool = False,
|
| 81 |
+
segment_duration: int = None,
|
| 82 |
+
) -> np.ndarray:
|
| 83 |
+
r"""Load audio file with target sampling rate and lsength
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
adfile (Path): path to audio file.
|
| 87 |
+
sampling_rate (int, optional): target sampling rate. Defaults to None.
|
| 88 |
+
length (int, optional): target audio length. Defaults to None.
|
| 89 |
+
volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
|
| 90 |
+
segment_duration (int): random select a segment with duration of {segment_duration}s.
|
| 91 |
+
Defualt to None which means the whole audio will be used.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
audio (np.ndarray): audio
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
audio, sr = soundfile.read(adfile)
|
| 98 |
+
if len(audio.shape) > 1:
|
| 99 |
+
audio = audio[:, 0]
|
| 100 |
+
|
| 101 |
+
if sampling_rate is not None and sr != sampling_rate:
|
| 102 |
+
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
|
| 103 |
+
sr = sampling_rate
|
| 104 |
+
|
| 105 |
+
if segment_duration is not None:
|
| 106 |
+
seg_length = int(sr * segment_duration)
|
| 107 |
+
audio = random_select_audio_segment(audio, seg_length)
|
| 108 |
+
|
| 109 |
+
# Audio volume normalize
|
| 110 |
+
if volume_normalize:
|
| 111 |
+
audio = audio_volume_normalize(audio)
|
| 112 |
+
# check the audio length
|
| 113 |
+
if length is not None:
|
| 114 |
+
assert abs(audio.shape[0] - length) < 1000
|
| 115 |
+
if audio.shape[0] > length:
|
| 116 |
+
audio = audio[:length]
|
| 117 |
+
else:
|
| 118 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
| 119 |
+
return audio
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
|
| 123 |
+
"""get an audio segment given the length
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
audio (np.ndarray):
|
| 127 |
+
length (int): audio length = sampling_rate * duration
|
| 128 |
+
"""
|
| 129 |
+
if audio.shape[0] < length:
|
| 130 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
| 131 |
+
start_index = random.randint(0, audio.shape[0] - length)
|
| 132 |
+
end_index = int(start_index + length)
|
| 133 |
+
|
| 134 |
+
return audio[start_index:end_index]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
|
| 138 |
+
"""apply highpass fileter to audio
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
audio (np.ndarray):
|
| 142 |
+
sample_rate (ind):
|
| 143 |
+
highpass_cutoff_freq (int):
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
audio = torchaudio.functional.highpass_biquad(
|
| 147 |
+
torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
|
| 148 |
+
)
|
| 149 |
+
return audio.numpy()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def stft(
|
| 153 |
+
x: torch.Tensor,
|
| 154 |
+
fft_size: int,
|
| 155 |
+
hop_size: int,
|
| 156 |
+
win_length: int,
|
| 157 |
+
window: str,
|
| 158 |
+
use_complex: bool = False,
|
| 159 |
+
) -> torch.Tensor:
|
| 160 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
| 161 |
+
Args:
|
| 162 |
+
x (Tensor): Input signal tensor (B, T).
|
| 163 |
+
fft_size (int): FFT size.
|
| 164 |
+
hop_size (int): Hop size.
|
| 165 |
+
win_length (int): Window length.
|
| 166 |
+
window (str): Window function type.
|
| 167 |
+
Returns:
|
| 168 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
x_stft = torch.stft(
|
| 172 |
+
x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# clamp is needed to avoid nan or inf
|
| 176 |
+
if not use_complex:
|
| 177 |
+
return torch.sqrt(
|
| 178 |
+
torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
|
| 179 |
+
).transpose(2, 1)
|
| 180 |
+
else:
|
| 181 |
+
res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
|
| 182 |
+
res = res.transpose(2, 3) # [B, 2, T, F]
|
| 183 |
+
return res
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def detect_speech_boundaries(
|
| 187 |
+
wav: np.ndarray,
|
| 188 |
+
sample_rate: int,
|
| 189 |
+
window_duration: float = 0.1,
|
| 190 |
+
energy_threshold: float = 0.01,
|
| 191 |
+
margin_factor: int = 2
|
| 192 |
+
) -> Tuple[int, int]:
|
| 193 |
+
"""Detect the start and end points of speech in an audio signal using RMS energy.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
wav: Input audio signal array with values in [-1, 1]
|
| 197 |
+
sample_rate: Audio sample rate in Hz
|
| 198 |
+
window_duration: Duration of detection window in seconds
|
| 199 |
+
energy_threshold: RMS energy threshold for speech detection
|
| 200 |
+
margin_factor: Factor to determine extra margin around detected boundaries
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
tuple: (start_index, end_index) of speech segment
|
| 204 |
+
|
| 205 |
+
Raises:
|
| 206 |
+
ValueError: If the audio contains only silence
|
| 207 |
+
"""
|
| 208 |
+
window_size = int(window_duration * sample_rate)
|
| 209 |
+
margin = margin_factor * window_size
|
| 210 |
+
step_size = window_size // 10
|
| 211 |
+
|
| 212 |
+
# Create sliding windows using stride tricks to avoid loops
|
| 213 |
+
windows = sliding_window_view(wav, window_size)[::step_size]
|
| 214 |
+
|
| 215 |
+
# Calculate RMS energy for each window
|
| 216 |
+
energy = np.sqrt(np.mean(windows ** 2, axis=1))
|
| 217 |
+
speech_mask = energy >= energy_threshold
|
| 218 |
+
|
| 219 |
+
if not np.any(speech_mask):
|
| 220 |
+
raise ValueError("No speech detected in audio (only silence)")
|
| 221 |
+
|
| 222 |
+
start = max(0, np.argmax(speech_mask) * step_size - margin)
|
| 223 |
+
end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
|
| 224 |
+
|
| 225 |
+
return start, end
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def remove_silence_on_both_ends(
|
| 229 |
+
wav: np.ndarray,
|
| 230 |
+
sample_rate: int,
|
| 231 |
+
window_duration: float = 0.1,
|
| 232 |
+
volume_threshold: float = 0.01
|
| 233 |
+
) -> np.ndarray:
|
| 234 |
+
"""Remove silence from both ends of an audio signal.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
wav: Input audio signal array
|
| 238 |
+
sample_rate: Audio sample rate in Hz
|
| 239 |
+
window_duration: Duration of detection window in seconds
|
| 240 |
+
volume_threshold: Amplitude threshold for silence detection
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
np.ndarray: Audio signal with silence removed from both ends
|
| 244 |
+
|
| 245 |
+
Raises:
|
| 246 |
+
ValueError: If the audio contains only silence
|
| 247 |
+
"""
|
| 248 |
+
start, end = detect_speech_boundaries(
|
| 249 |
+
wav,
|
| 250 |
+
sample_rate,
|
| 251 |
+
window_duration,
|
| 252 |
+
volume_threshold
|
| 253 |
+
)
|
| 254 |
+
return wav[start:end]
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def hertz_to_mel(pitch: float) -> float:
|
| 259 |
+
"""
|
| 260 |
+
Converts a frequency from the Hertz scale to the Mel scale.
|
| 261 |
+
|
| 262 |
+
Parameters:
|
| 263 |
+
- pitch: float or ndarray
|
| 264 |
+
Frequency in Hertz.
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
- mel: float or ndarray
|
| 268 |
+
Frequency in Mel scale.
|
| 269 |
+
"""
|
| 270 |
+
mel = 2595 * np.log10(1 + pitch / 700)
|
| 271 |
+
return mel
|
models/bicodec_tokenizer/utils/file.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Description:
|
| 17 |
+
This script contains a collection of functions designed to handle various
|
| 18 |
+
file reading and writing operations. It provides utilities to read from files,
|
| 19 |
+
write data to files, and perform file manipulation tasks.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
import json
|
| 25 |
+
import json
|
| 26 |
+
import csv
|
| 27 |
+
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
from typing import List, Dict, Any, Set, Union
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from omegaconf import OmegaConf, DictConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
|
| 35 |
+
"""
|
| 36 |
+
Resolves the absolute path of a symbolic link.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
symbolic_link_path (Path): The path to the symbolic link.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Path: The absolute path that the symbolic link points to.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
link_directory = os.path.dirname(symbolic_link_path)
|
| 46 |
+
target_path_relative = os.readlink(symbolic_link_path)
|
| 47 |
+
return os.path.join(link_directory, target_path_relative)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def write_jsonl(metadata: List[dict], file_path: Path) -> None:
|
| 51 |
+
"""Writes a list of dictionaries to a JSONL file.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
metadata : List[dict]
|
| 55 |
+
A list of dictionaries, each representing a piece of meta.
|
| 56 |
+
file_path : Path
|
| 57 |
+
The file path to save the JSONL file
|
| 58 |
+
|
| 59 |
+
This function writes each dictionary in the list to a new line in the specified file.
|
| 60 |
+
"""
|
| 61 |
+
with open(file_path, "w", encoding="utf-8") as f:
|
| 62 |
+
for meta in tqdm(metadata, desc="writing jsonl"):
|
| 63 |
+
# Convert dictionary to JSON string and write it to the file with a newline
|
| 64 |
+
json_str = json.dumps(meta, ensure_ascii=False) + "\n"
|
| 65 |
+
f.write(json_str)
|
| 66 |
+
print(f"jsonl saved to {file_path}")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def read_jsonl(file_path: Path) -> List[dict]:
|
| 70 |
+
"""
|
| 71 |
+
Reads a JSONL file and returns a list of dictionaries.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
file_path : Path
|
| 75 |
+
The path to the JSONL file to be read.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
List[dict]
|
| 79 |
+
A list of dictionaries parsed from each line of the JSONL file.
|
| 80 |
+
"""
|
| 81 |
+
metadata = []
|
| 82 |
+
# Open the file for reading
|
| 83 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 84 |
+
# Split the file into lines
|
| 85 |
+
lines = f.read().splitlines()
|
| 86 |
+
# Process each line
|
| 87 |
+
for line in lines:
|
| 88 |
+
# Convert JSON string back to dictionary and append to list
|
| 89 |
+
meta = json.loads(line)
|
| 90 |
+
metadata.append(meta)
|
| 91 |
+
# Return the list of metadata
|
| 92 |
+
return metadata
|
| 93 |
+
|
| 94 |
+
def read_json_as_jsonl(file_path: Path) -> List[dict]:
|
| 95 |
+
metadata = []
|
| 96 |
+
with open(file_path, 'r', encoding='utf-8') as infile:
|
| 97 |
+
data = json.load(infile)
|
| 98 |
+
for k in sorted(data.keys()):
|
| 99 |
+
meta = {'index': k}
|
| 100 |
+
meta.update(data[k])
|
| 101 |
+
metadata.append(meta)
|
| 102 |
+
return metadata
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 107 |
+
processed_meta = {}
|
| 108 |
+
for k, v in meta.items():
|
| 109 |
+
if isinstance(v, str):
|
| 110 |
+
processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
|
| 111 |
+
else:
|
| 112 |
+
processed_meta[k] = v
|
| 113 |
+
return processed_meta
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def load_config(config_path: Path) -> DictConfig:
|
| 117 |
+
"""Loads a configuration file and optionally merges it with a base configuration.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
config_path (Path): Path to the configuration file.
|
| 121 |
+
"""
|
| 122 |
+
# Load the initial configuration from the given path
|
| 123 |
+
config = OmegaConf.load(config_path)
|
| 124 |
+
|
| 125 |
+
# Check if there is a base configuration specified and merge if necessary
|
| 126 |
+
if config.get("base_config", None) is not None:
|
| 127 |
+
base_config = OmegaConf.load(config["base_config"])
|
| 128 |
+
config = OmegaConf.merge(base_config, config)
|
| 129 |
+
|
| 130 |
+
return config
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
|
| 135 |
+
"""
|
| 136 |
+
Converts a JSONL file to a CSV file.
|
| 137 |
+
|
| 138 |
+
This function reads a JSONL file, determines all unique keys present in the file,
|
| 139 |
+
and writes the data to a CSV file with columns for all these keys.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
all_keys = set()
|
| 143 |
+
data_rows = []
|
| 144 |
+
|
| 145 |
+
# Read the JSONL file once to extract keys and collect data
|
| 146 |
+
with open(jsonl_file_path, 'r') as file:
|
| 147 |
+
for line in file:
|
| 148 |
+
data = json.loads(line.strip())
|
| 149 |
+
data_rows.append(data)
|
| 150 |
+
all_keys.update(data.keys())
|
| 151 |
+
|
| 152 |
+
# Convert the set of keys to a sorted list for consistent column order
|
| 153 |
+
sorted_keys = sorted(all_keys)
|
| 154 |
+
|
| 155 |
+
# Write the data to a CSV file
|
| 156 |
+
with open(csv_file_path, 'w', newline='') as csvfile:
|
| 157 |
+
writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
|
| 158 |
+
|
| 159 |
+
# Write the header row
|
| 160 |
+
writer.writeheader()
|
| 161 |
+
|
| 162 |
+
# Write each row of data
|
| 163 |
+
for data in data_rows:
|
| 164 |
+
writer.writerow(data)
|
| 165 |
+
|
| 166 |
+
print(f"CSV file has been created at {csv_file_path}")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def save_metadata(data, filename, headers=None):
|
| 170 |
+
"""
|
| 171 |
+
Save metadata to a file.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
data (list of dict): Metadata to be saved.
|
| 175 |
+
filename (str): Name of the file to save the metadata.
|
| 176 |
+
headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
|
| 177 |
+
"""
|
| 178 |
+
# Set headers to keys from the first dictionary in data if not explicitly provided
|
| 179 |
+
if headers is None:
|
| 180 |
+
headers = list(data[0].keys())
|
| 181 |
+
|
| 182 |
+
with open(filename, "w", encoding="utf-8") as file:
|
| 183 |
+
# Write the headers to the file
|
| 184 |
+
file.write("|".join(headers) + "\n")
|
| 185 |
+
for entry in data:
|
| 186 |
+
# Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
|
| 187 |
+
formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
|
| 188 |
+
# Write the formatted values to the file
|
| 189 |
+
file.write("|".join(formatted_values) + "\n")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def read_metadata(filename, headers=None):
|
| 193 |
+
"""
|
| 194 |
+
Read metadata from a file.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
filename (str): The file from which to read the metadata.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
list of dict: The metadata read from the file.
|
| 201 |
+
list of str: The headers used in the file.
|
| 202 |
+
"""
|
| 203 |
+
with open(filename, "r", encoding="utf-8") as file:
|
| 204 |
+
lines = file.readlines()
|
| 205 |
+
|
| 206 |
+
data = []
|
| 207 |
+
# Set headers from the first line of the file if not provided
|
| 208 |
+
if headers is None:
|
| 209 |
+
headers = lines[0].strip().split("|")
|
| 210 |
+
lines = lines[1:]
|
| 211 |
+
|
| 212 |
+
for line in lines:
|
| 213 |
+
line = line.strip()
|
| 214 |
+
# Skip empty lines
|
| 215 |
+
if not line:
|
| 216 |
+
continue
|
| 217 |
+
# Split the line by '|' and pair with headers to form a dictionary
|
| 218 |
+
entry_data = dict(zip(headers, line.split("|")))
|
| 219 |
+
data.append(entry_data)
|
| 220 |
+
|
| 221 |
+
return data, headers
|
models/bicodec_tokenizer/utils/parse_options.sh
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
| 4 |
+
# Arnab Ghoshal, Karel Vesely
|
| 5 |
+
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
| 13 |
+
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
| 14 |
+
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
| 15 |
+
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
| 16 |
+
# See the Apache 2 License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Parse command-line options.
|
| 21 |
+
# To be sourced by another script (as in ". parse_options.sh").
|
| 22 |
+
# Option format is: --option-name arg
|
| 23 |
+
# and shell variable "option_name" gets set to value "arg."
|
| 24 |
+
# The exception is --help, which takes no arguments, but prints the
|
| 25 |
+
# $help_message variable (if defined).
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
###
|
| 29 |
+
### The --config file options have lower priority to command line
|
| 30 |
+
### options, so we need to import them first...
|
| 31 |
+
###
|
| 32 |
+
|
| 33 |
+
# Now import all the configs specified by command-line, in left-to-right order
|
| 34 |
+
# for ((argpos=1; argpos<$#; argpos++)); do
|
| 35 |
+
# if [ "${!argpos}" == "--config" ]; then
|
| 36 |
+
# argpos_plus1=$((argpos+1))
|
| 37 |
+
# config=${!argpos_plus1}
|
| 38 |
+
# [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
| 39 |
+
# . $config # source the config file.
|
| 40 |
+
# fi
|
| 41 |
+
# done
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
###
|
| 45 |
+
### No we process the command line options
|
| 46 |
+
###
|
| 47 |
+
while true; do
|
| 48 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
| 49 |
+
case "$1" in
|
| 50 |
+
# If the enclosing script is called with --help option, print the help
|
| 51 |
+
# message and exit. Scripts should put help messages in $help_message
|
| 52 |
+
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
| 53 |
+
else printf "$help_message\n" 1>&2 ; fi;
|
| 54 |
+
exit 0 ;;
|
| 55 |
+
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
| 56 |
+
exit 1 ;;
|
| 57 |
+
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
| 58 |
+
# then work out the variable name as $name, which will equal "foo_bar".
|
| 59 |
+
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
| 60 |
+
# Next we test whether the variable in question is undefned-- if so it's
|
| 61 |
+
# an invalid option and we die. Note: $0 evaluates to the name of the
|
| 62 |
+
# enclosing script.
|
| 63 |
+
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
| 64 |
+
# is undefined. We then have to wrap this test inside "eval" because
|
| 65 |
+
# foo_bar is itself inside a variable ($name).
|
| 66 |
+
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
| 67 |
+
|
| 68 |
+
oldval="`eval echo \\$$name`";
|
| 69 |
+
# Work out whether we seem to be expecting a Boolean argument.
|
| 70 |
+
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
| 71 |
+
was_bool=true;
|
| 72 |
+
else
|
| 73 |
+
was_bool=false;
|
| 74 |
+
fi
|
| 75 |
+
|
| 76 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
| 77 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
| 78 |
+
eval $name=\"$2\";
|
| 79 |
+
|
| 80 |
+
# Check that Boolean-valued arguments are really Boolean.
|
| 81 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
| 82 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
| 83 |
+
exit 1;
|
| 84 |
+
fi
|
| 85 |
+
shift 2;
|
| 86 |
+
;;
|
| 87 |
+
*) break;
|
| 88 |
+
esac
|
| 89 |
+
done
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Check for an empty argument to the --cmd option, which can easily occur as a
|
| 93 |
+
# result of scripting errors.
|
| 94 |
+
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
true; # so this script returns exit code 0.
|
models/bicodec_tokenizer/utils/token_parser.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK_TOKEN_MAP = {
|
| 2 |
+
"vc": "<|task_vc|>",
|
| 3 |
+
"tts": "<|task_tts|>",
|
| 4 |
+
"asr": "<|task_asr|>",
|
| 5 |
+
"s2s": "<|task_s2s|>",
|
| 6 |
+
"t2s": "<|task_t2s|>",
|
| 7 |
+
"understand": "<|task_understand|>",
|
| 8 |
+
"caption": "<|task_cap|>",
|
| 9 |
+
"controllable_tts": "<|task_controllable_tts|>",
|
| 10 |
+
"prompt_tts": "<|task_prompt_tts|>",
|
| 11 |
+
"speech_edit": "<|task_edit|>",
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
LEVELS_MAP = {
|
| 15 |
+
"very_low": 0,
|
| 16 |
+
"low": 1,
|
| 17 |
+
"moderate": 2,
|
| 18 |
+
"high": 3,
|
| 19 |
+
"very_high": 4,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
LEVELS_MAP_UI = {
|
| 23 |
+
1: 'very_low',
|
| 24 |
+
2: 'low',
|
| 25 |
+
3: 'moderate',
|
| 26 |
+
4: 'high',
|
| 27 |
+
5: 'very_high'
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
GENDER_MAP = {
|
| 31 |
+
"female": 0,
|
| 32 |
+
"male": 1,
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
|
| 36 |
+
|
| 37 |
+
EMO_MAP = {
|
| 38 |
+
"UNKNOWN": 0,
|
| 39 |
+
"NEUTRAL": 1,
|
| 40 |
+
"ANGRY": 2,
|
| 41 |
+
"HAPPY": 3,
|
| 42 |
+
"SAD": 4,
|
| 43 |
+
"FEARFUL": 5,
|
| 44 |
+
"DISGUSTED": 6,
|
| 45 |
+
"SURPRISED": 7,
|
| 46 |
+
"SARCASTIC": 8,
|
| 47 |
+
"EXCITED": 9,
|
| 48 |
+
"SLEEPY": 10,
|
| 49 |
+
"CONFUSED": 11,
|
| 50 |
+
"EMPHASIS": 12,
|
| 51 |
+
"LAUGHING": 13,
|
| 52 |
+
"SINGING": 14,
|
| 53 |
+
"WORRIED": 15,
|
| 54 |
+
"WHISPER": 16,
|
| 55 |
+
"ANXIOUS": 17,
|
| 56 |
+
"NO-AGREEMENT": 18,
|
| 57 |
+
"APOLOGETIC": 19,
|
| 58 |
+
"CONCERNED": 20,
|
| 59 |
+
"ENUNCIATED": 21,
|
| 60 |
+
"ASSERTIVE": 22,
|
| 61 |
+
"ENCOURAGING": 23,
|
| 62 |
+
"CONTEMPT": 24,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class TokenParser:
|
| 67 |
+
"""Turn label to special token"""
|
| 68 |
+
|
| 69 |
+
def __init__(self):
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
"""Parse the attributes of a person."""
|
| 73 |
+
|
| 74 |
+
def __init__(self):
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
@staticmethod
|
| 78 |
+
def age(age: str) -> str:
|
| 79 |
+
"""Turn age token."""
|
| 80 |
+
age_id = AGE_MAP[age]
|
| 81 |
+
return f"<|age_{age_id}|>"
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def gender(gender: str) -> str:
|
| 85 |
+
"""Turn gender token."""
|
| 86 |
+
gender_id = GENDER_MAP[gender]
|
| 87 |
+
return f"<|gender_{gender_id}|>"
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def mel_value(mel: int):
|
| 91 |
+
"""Turn special token of mel scale pitch."""
|
| 92 |
+
mel = max(0, int(mel))
|
| 93 |
+
mel = min(1000, int(mel))
|
| 94 |
+
return f"<|pitch_value_{mel}|>"
|
| 95 |
+
|
| 96 |
+
@staticmethod
|
| 97 |
+
def mel_level(level: str):
|
| 98 |
+
"""Turn special token of mel level."""
|
| 99 |
+
level_tag = LEVELS_MAP[level]
|
| 100 |
+
return f"<|pitch_label_{level_tag}|>"
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def pitch_var_value(pitch_std: int):
|
| 104 |
+
"""Turn special token of pitch_std value."""
|
| 105 |
+
assert isinstance(pitch_std, int)
|
| 106 |
+
pitch_std = max(0, int(pitch_std))
|
| 107 |
+
pitch_std = min(10, int(pitch_std))
|
| 108 |
+
return f"<|pitch_var_value_{pitch_std}|>"
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def pitch_var_level(level: str):
|
| 112 |
+
"""Turn special token of pitch std level."""
|
| 113 |
+
level_tag = LEVELS_MAP[level]
|
| 114 |
+
return f"<|pitch_var_label_{level_tag}|>"
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def loudness_value(loudness: int):
|
| 118 |
+
"""Turn special toak of loudness value [0, 30]"""
|
| 119 |
+
assert loudness >= 0
|
| 120 |
+
loudness = max(0, int(loudness))
|
| 121 |
+
loudness = min(30, int(loudness))
|
| 122 |
+
return f"<|loudness_value_{loudness}|>"
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def loudness_level(level: str):
|
| 126 |
+
"""Turn special token of loudness level."""
|
| 127 |
+
level_tag = LEVELS_MAP[level]
|
| 128 |
+
return f"<|loudness_label_{level_tag}|>"
|
| 129 |
+
|
| 130 |
+
@staticmethod
|
| 131 |
+
def speed_value(speed: int):
|
| 132 |
+
"""Turn special token of speed value."""
|
| 133 |
+
speed = max(0, int(speed))
|
| 134 |
+
speed = min(10, int(speed))
|
| 135 |
+
return f"<|speed_value_{speed}|>"
|
| 136 |
+
|
| 137 |
+
@staticmethod
|
| 138 |
+
def speed_level(level: str):
|
| 139 |
+
"""Turn special token of speed level."""
|
| 140 |
+
level_tag = LEVELS_MAP[level]
|
| 141 |
+
return f"<|speed_label_{level_tag}|>"
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
def task(task: str) -> str:
|
| 145 |
+
"""Turn special token of task."""
|
| 146 |
+
assert task in TASK_TOKEN_MAP.keys()
|
| 147 |
+
|
| 148 |
+
return TASK_TOKEN_MAP[task]
|
| 149 |
+
|
| 150 |
+
@staticmethod
|
| 151 |
+
def emotion(emotion: str):
|
| 152 |
+
emo_id = EMO_MAP[emotion]
|
| 153 |
+
|
| 154 |
+
return f"<|emotion_{emo_id}|>"
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# test
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
+
from transformers import AutoTokenizer
|
| 160 |
+
|
| 161 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 162 |
+
"/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
|
| 166 |
+
ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
|
| 167 |
+
genders = ["female", "female", "female", "male", "male"]
|
| 168 |
+
mels = [100, 200, 300, 400, 500]
|
| 169 |
+
mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
| 170 |
+
loudnesses = [1, 10, 23, 19, 30]
|
| 171 |
+
loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
| 172 |
+
emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
|
| 173 |
+
|
| 174 |
+
for i in range(5):
|
| 175 |
+
task = TokenParser.task(tasks[i])
|
| 176 |
+
age = TokenParser.age(ages[i])
|
| 177 |
+
gender = TokenParser.gender(genders[i])
|
| 178 |
+
mel = TokenParser.mel_value(mels[i])
|
| 179 |
+
mel_level = TokenParser.mel_level(mel_levels[i])
|
| 180 |
+
loudness = TokenParser.loudness_value(loudnesses[i])
|
| 181 |
+
loudness_level = TokenParser.loudness_level(loudness_levels[i])
|
| 182 |
+
emotion = TokenParser.emotion(emotions[i])
|
| 183 |
+
inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
|
| 184 |
+
inputs = "".join(inputs)
|
| 185 |
+
ids = tokenizer.encode(inputs, add_special_tokens=False)
|
| 186 |
+
print(ids)
|
| 187 |
+
print("decode", tokenizer.decode(ids))
|
models/glm_speech_tokenizer/__init__.py
ADDED
|
File without changes
|
models/glm_speech_tokenizer/batch_processor.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Time :2024/11/17 15:33
|
| 3 |
+
# Author :Hui Huang
|
| 4 |
+
import asyncio
|
| 5 |
+
import uuid
|
| 6 |
+
from typing import Callable, List, Any, Awaitable, Tuple
|
| 7 |
+
from asyncio import Queue
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BatchProcessor:
|
| 11 |
+
"""Batch Processor for handling asynchronous requests in batches.
|
| 12 |
+
|
| 13 |
+
This class manages a queue of requests and processes them in batches
|
| 14 |
+
using multiple worker tasks.
|
| 15 |
+
|
| 16 |
+
Attributes:
|
| 17 |
+
processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
|
| 18 |
+
The function used for processing requests in batches.
|
| 19 |
+
num_workers (int): The number of worker tasks to process requests.
|
| 20 |
+
batch_size (int): The maximum number of requests to process in a single batch.
|
| 21 |
+
request_queue (Queue): The queue holding incoming requests.
|
| 22 |
+
loop (asyncio.AbstractEventLoop): The event loop used to create worker tasks.
|
| 23 |
+
worker_tasks (List[asyncio.Task]): The list of worker tasks.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
|
| 29 |
+
num_workers: int,
|
| 30 |
+
batch_size: int,
|
| 31 |
+
wait_timeout: float = 0.05
|
| 32 |
+
) -> None:
|
| 33 |
+
"""Initialize the BatchProcessor with the given processing function, number of workers, and batch size.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
|
| 37 |
+
The function used for processing requests in batches.
|
| 38 |
+
num_workers (int): The number of worker tasks to process requests.
|
| 39 |
+
batch_size (int): The maximum number of requests to process in a single batch.
|
| 40 |
+
"""
|
| 41 |
+
self.processing_function = processing_function
|
| 42 |
+
self.num_workers = num_workers
|
| 43 |
+
self.batch_size = batch_size
|
| 44 |
+
self.wait_timeout = wait_timeout
|
| 45 |
+
self.request_queue: Queue = Queue()
|
| 46 |
+
self.loop = asyncio.get_running_loop()
|
| 47 |
+
self.worker_tasks = [
|
| 48 |
+
self.loop.create_task(self.batch_processor(i)) for i in range(num_workers)
|
| 49 |
+
]
|
| 50 |
+
# Wait until all worker tasks are started
|
| 51 |
+
self.loop.create_task(self._log_workers_started())
|
| 52 |
+
|
| 53 |
+
async def _log_workers_started(self):
|
| 54 |
+
await asyncio.sleep(0) # Yield control to ensure workers have started
|
| 55 |
+
|
| 56 |
+
async def batch_processor(self, worker_id: int):
|
| 57 |
+
"""Worker task that processes requests from the queue in batches.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
worker_id (int): The identifier for the worker task.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
while True:
|
| 64 |
+
requests: List[Tuple[Any, asyncio.Future]] = []
|
| 65 |
+
try:
|
| 66 |
+
while len(requests) < self.batch_size:
|
| 67 |
+
request = await asyncio.wait_for(
|
| 68 |
+
self.request_queue.get(), timeout=self.wait_timeout
|
| 69 |
+
)
|
| 70 |
+
requests.append(request)
|
| 71 |
+
except asyncio.TimeoutError:
|
| 72 |
+
pass
|
| 73 |
+
|
| 74 |
+
if requests:
|
| 75 |
+
all_requests = [
|
| 76 |
+
req[0] for req in requests
|
| 77 |
+
] # Extract the actual input data from each request tuple
|
| 78 |
+
futures = [req[1] for req in requests] # Extract the futures to resolve
|
| 79 |
+
try:
|
| 80 |
+
results = await self.processing_function(all_requests)
|
| 81 |
+
|
| 82 |
+
for (future, result) in zip(futures, results):
|
| 83 |
+
future.set_result(result)
|
| 84 |
+
except Exception as e:
|
| 85 |
+
for future in futures:
|
| 86 |
+
future.set_exception(e)
|
| 87 |
+
|
| 88 |
+
async def add_request(self, single_input: Any):
|
| 89 |
+
"""Add a new request to the queue.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
single_input (Any): The input data for processing.
|
| 93 |
+
"""
|
| 94 |
+
# loop = asyncio.get_running_loop()
|
| 95 |
+
future = self.loop.create_future()
|
| 96 |
+
self.request_queue.put_nowait((single_input, future))
|
| 97 |
+
return future
|
| 98 |
+
|
| 99 |
+
async def shutdown(self):
|
| 100 |
+
"""Shutdown the batch processor by cancelling all worker tasks."""
|
| 101 |
+
for task in self.worker_tasks:
|
| 102 |
+
task.cancel()
|
| 103 |
+
try:
|
| 104 |
+
await task
|
| 105 |
+
except asyncio.CancelledError:
|
| 106 |
+
print("Worker task cancelled.")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class AsyncBatchEngine:
|
| 110 |
+
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
|
| 114 |
+
batch_size: int = 32,
|
| 115 |
+
wait_timeout: float = 0.01,
|
| 116 |
+
):
|
| 117 |
+
"""
|
| 118 |
+
Initialize the AsyncBatchEngine with a processing function, number of workers, and batch size.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
processing_function (Callable[[List[Any]], Awaitable[List[Any]]]): The batch processing function.
|
| 122 |
+
batch_size (int): The maximum number of requests to process in a single batch.
|
| 123 |
+
"""
|
| 124 |
+
self._processing_function = processing_function
|
| 125 |
+
self._batch_size = batch_size
|
| 126 |
+
self._is_running = False
|
| 127 |
+
self._batch_processor = None
|
| 128 |
+
self._wait_timeout = wait_timeout
|
| 129 |
+
|
| 130 |
+
async def start(self):
|
| 131 |
+
"""Start the engine by initializing the batch processor and worker tasks."""
|
| 132 |
+
if self._is_running:
|
| 133 |
+
return
|
| 134 |
+
|
| 135 |
+
self._batch_processor = BatchProcessor(
|
| 136 |
+
processing_function=self._processing_function,
|
| 137 |
+
batch_size=self._batch_size,
|
| 138 |
+
wait_timeout=self._wait_timeout,
|
| 139 |
+
num_workers=1
|
| 140 |
+
)
|
| 141 |
+
self._is_running = True
|
| 142 |
+
|
| 143 |
+
async def stop(self):
|
| 144 |
+
"""Stop the engine by shutting down the batch processor and worker tasks."""
|
| 145 |
+
self._check_running()
|
| 146 |
+
self._is_running = False
|
| 147 |
+
if self._batch_processor is not None:
|
| 148 |
+
await self._batch_processor.shutdown()
|
| 149 |
+
|
| 150 |
+
def _check_running(self):
|
| 151 |
+
"""Check if the engine is running.
|
| 152 |
+
|
| 153 |
+
Raises:
|
| 154 |
+
ValueError: If the engine is not running.
|
| 155 |
+
"""
|
| 156 |
+
if not self._is_running:
|
| 157 |
+
raise ValueError(
|
| 158 |
+
"The engine is not running. "
|
| 159 |
+
"You must start the engine before using it."
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
async def add_request(self, single_input: Any, request_id: str = None) -> dict:
|
| 163 |
+
"""Asynchronously add a request to be processed.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
single_input (Any): The input data for processing.
|
| 167 |
+
request_id (str): Optional request identifier to avoid data mix-up.
|
| 168 |
+
|
| 169 |
+
Raises:
|
| 170 |
+
ValueError: If the engine is not running when this method is called.
|
| 171 |
+
"""
|
| 172 |
+
if not self._is_running:
|
| 173 |
+
await self.start()
|
| 174 |
+
|
| 175 |
+
if request_id is None:
|
| 176 |
+
request_id = str(uuid.uuid4()) # Assign a unique ID if not provided
|
| 177 |
+
future = await self._batch_processor.add_request(single_input=single_input) # type: ignore
|
| 178 |
+
result = await future
|
| 179 |
+
return dict(
|
| 180 |
+
request_id=request_id,
|
| 181 |
+
feature=result
|
| 182 |
+
)
|
models/glm_speech_tokenizer/configuration_whisper.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import WhisperConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class WhisperVQConfig(WhisperConfig):
|
| 5 |
+
def __init__(self,
|
| 6 |
+
pooling_kernel_size=None,
|
| 7 |
+
pooling_type="max",
|
| 8 |
+
pooling_position=0,
|
| 9 |
+
quantize_vocab_size=None,
|
| 10 |
+
quantize_position=16,
|
| 11 |
+
quantize_commit_coefficient=0.25,
|
| 12 |
+
quantize_loss_scale=1.0,
|
| 13 |
+
quantize_ema_decay=None,
|
| 14 |
+
quantize_restart_interval=None,
|
| 15 |
+
quantize_encoder_only=False,
|
| 16 |
+
quantize_causal_encoder=False,
|
| 17 |
+
quantize_causal_block_size=None,
|
| 18 |
+
skip_language_detection=False,
|
| 19 |
+
encoder_causal_attention=False,
|
| 20 |
+
encoder_causal_convolution=False,
|
| 21 |
+
**kwargs):
|
| 22 |
+
self.pooling_kernel_size = pooling_kernel_size
|
| 23 |
+
self.pooling_type = pooling_type
|
| 24 |
+
self.pooling_position = pooling_position
|
| 25 |
+
self.quantize_vocab_size = quantize_vocab_size
|
| 26 |
+
self.quantize_position = quantize_position
|
| 27 |
+
self.quantize_commit_coefficient = quantize_commit_coefficient
|
| 28 |
+
self.quantize_loss_scale = quantize_loss_scale
|
| 29 |
+
self.quantize_ema_decay = quantize_ema_decay
|
| 30 |
+
self.quantize_restart_interval = quantize_restart_interval
|
| 31 |
+
self.quantize_encoder_only = quantize_encoder_only
|
| 32 |
+
self.quantize_causal_encoder = quantize_causal_encoder
|
| 33 |
+
self.quantize_causal_block_size = quantize_causal_block_size
|
| 34 |
+
self.skip_language_detection = skip_language_detection
|
| 35 |
+
self.encoder_causal_attention = encoder_causal_attention
|
| 36 |
+
self.encoder_causal_convolution = encoder_causal_convolution
|
| 37 |
+
super().__init__(**kwargs)
|
models/glm_speech_tokenizer/generation_whisper.py
ADDED
|
@@ -0,0 +1,1828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import copy
|
| 16 |
+
import math
|
| 17 |
+
import warnings
|
| 18 |
+
import zlib
|
| 19 |
+
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
from transformers.cache_utils import EncoderDecoderCache
|
| 27 |
+
|
| 28 |
+
from transformers.generation.configuration_utils import GenerationConfig
|
| 29 |
+
from transformers.generation.logits_process import (
|
| 30 |
+
LogitsProcessorList,
|
| 31 |
+
SuppressTokensAtBeginLogitsProcessor,
|
| 32 |
+
SuppressTokensLogitsProcessor,
|
| 33 |
+
WhisperNoSpeechDetection,
|
| 34 |
+
WhisperTimeStampLogitsProcessor,
|
| 35 |
+
)
|
| 36 |
+
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
| 37 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 38 |
+
from transformers.utils import logging
|
| 39 |
+
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
|
| 46 |
+
"""
|
| 47 |
+
Applies a median filter of width `filter_width` along the last dimension of the input.
|
| 48 |
+
|
| 49 |
+
The `inputs` tensor is assumed to be 3- or 4-dimensional.
|
| 50 |
+
"""
|
| 51 |
+
if filter_width <= 0 or filter_width % 2 != 1:
|
| 52 |
+
raise ValueError("`filter_width` should be an odd number")
|
| 53 |
+
|
| 54 |
+
pad_width = filter_width // 2
|
| 55 |
+
if inputs.shape[-1] <= pad_width:
|
| 56 |
+
return inputs
|
| 57 |
+
|
| 58 |
+
# Pad the left and right edges.
|
| 59 |
+
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
|
| 60 |
+
|
| 61 |
+
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
| 62 |
+
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
|
| 63 |
+
return result
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _dynamic_time_warping(matrix: np.ndarray):
|
| 67 |
+
"""
|
| 68 |
+
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
|
| 69 |
+
token-level timestamps.
|
| 70 |
+
"""
|
| 71 |
+
output_length, input_length = matrix.shape
|
| 72 |
+
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
|
| 73 |
+
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
|
| 74 |
+
|
| 75 |
+
cost[0, 0] = 0
|
| 76 |
+
for j in range(1, input_length + 1):
|
| 77 |
+
for i in range(1, output_length + 1):
|
| 78 |
+
c0 = cost[i - 1, j - 1]
|
| 79 |
+
c1 = cost[i - 1, j]
|
| 80 |
+
c2 = cost[i, j - 1]
|
| 81 |
+
|
| 82 |
+
if c0 < c1 and c0 < c2:
|
| 83 |
+
c, t = c0, 0
|
| 84 |
+
elif c1 < c0 and c1 < c2:
|
| 85 |
+
c, t = c1, 1
|
| 86 |
+
else:
|
| 87 |
+
c, t = c2, 2
|
| 88 |
+
|
| 89 |
+
cost[i, j] = matrix[i - 1, j - 1] + c
|
| 90 |
+
trace[i, j] = t
|
| 91 |
+
|
| 92 |
+
# backtrace
|
| 93 |
+
i = trace.shape[0] - 1
|
| 94 |
+
j = trace.shape[1] - 1
|
| 95 |
+
trace[0, :] = 2
|
| 96 |
+
trace[:, 0] = 1
|
| 97 |
+
|
| 98 |
+
text_indices = []
|
| 99 |
+
time_indices = []
|
| 100 |
+
while i > 0 or j > 0:
|
| 101 |
+
text_indices.append(i - 1)
|
| 102 |
+
time_indices.append(j - 1)
|
| 103 |
+
if trace[i, j] == 0:
|
| 104 |
+
i -= 1
|
| 105 |
+
j -= 1
|
| 106 |
+
elif trace[i, j] == 1:
|
| 107 |
+
i -= 1
|
| 108 |
+
elif trace[i, j] == 2:
|
| 109 |
+
j -= 1
|
| 110 |
+
else:
|
| 111 |
+
raise RuntimeError(
|
| 112 |
+
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
text_indices = np.array(text_indices)[::-1]
|
| 116 |
+
time_indices = np.array(time_indices)[::-1]
|
| 117 |
+
return text_indices, time_indices
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
|
| 121 |
+
if logits_processor is not None:
|
| 122 |
+
logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
|
| 123 |
+
if logit_processor:
|
| 124 |
+
return getattr(logit_processor, attribute_name, None)
|
| 125 |
+
return None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _pad_to_max_length(
|
| 129 |
+
current_segments,
|
| 130 |
+
pad_token_id,
|
| 131 |
+
device,
|
| 132 |
+
padding_side="right",
|
| 133 |
+
padding="longest",
|
| 134 |
+
bos_token_tensor=None,
|
| 135 |
+
cut_off_length=None,
|
| 136 |
+
):
|
| 137 |
+
max_total_length = 0
|
| 138 |
+
sequences = []
|
| 139 |
+
|
| 140 |
+
if padding_side not in ["right", "left"]:
|
| 141 |
+
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
|
| 142 |
+
|
| 143 |
+
if padding not in ["longest", "max_length"]:
|
| 144 |
+
raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
|
| 145 |
+
elif padding == "max_length" and cut_off_length is None:
|
| 146 |
+
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
|
| 147 |
+
|
| 148 |
+
for current_segment_list in current_segments:
|
| 149 |
+
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
|
| 150 |
+
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
|
| 151 |
+
|
| 152 |
+
if cut_off_length is not None:
|
| 153 |
+
sequence = sequence[-cut_off_length:]
|
| 154 |
+
|
| 155 |
+
if bos_token_tensor is not None:
|
| 156 |
+
sequence = torch.cat([bos_token_tensor, sequence])
|
| 157 |
+
|
| 158 |
+
sequences.append(sequence)
|
| 159 |
+
max_total_length = max(max_total_length, len(sequences[-1]))
|
| 160 |
+
elif bos_token_tensor is not None:
|
| 161 |
+
sequences.append(bos_token_tensor)
|
| 162 |
+
else:
|
| 163 |
+
sequences.append(torch.tensor([], device=device))
|
| 164 |
+
|
| 165 |
+
max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
|
| 166 |
+
for i in range(len(current_segments)):
|
| 167 |
+
pad_length = max_total_length - len(sequences[i])
|
| 168 |
+
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
|
| 169 |
+
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
|
| 170 |
+
|
| 171 |
+
sequences = torch.stack(sequences, dim=0)
|
| 172 |
+
return sequences
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class WhisperGenerationMixin:
|
| 176 |
+
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
|
| 177 |
+
"""
|
| 178 |
+
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
|
| 179 |
+
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
|
| 180 |
+
cross-attentions will be cropped before applying DTW.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
tensor containing the timestamps in seconds for each predicted token
|
| 184 |
+
"""
|
| 185 |
+
# Create a list with `decoder_layers` elements, each a tensor of shape
|
| 186 |
+
# (batch size, attention_heads, output length, input length).
|
| 187 |
+
cross_attentions = []
|
| 188 |
+
for i in range(self.config.decoder_layers):
|
| 189 |
+
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
|
| 190 |
+
|
| 191 |
+
# Select specific cross-attention layers and heads. This is a tensor
|
| 192 |
+
# of shape (batch size, num selected, output length, input length).
|
| 193 |
+
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
|
| 194 |
+
weights = weights.permute([1, 0, 2, 3])
|
| 195 |
+
|
| 196 |
+
weight_length = None
|
| 197 |
+
|
| 198 |
+
if "beam_indices" in generate_outputs:
|
| 199 |
+
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
|
| 200 |
+
# since the beam search strategy chooses the most probable sequences at the end of the search.
|
| 201 |
+
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
|
| 202 |
+
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
|
| 203 |
+
weights = weights[:, :, :weight_length]
|
| 204 |
+
|
| 205 |
+
# If beam index is still -1, it means that the associated token id is EOS
|
| 206 |
+
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
|
| 207 |
+
beam_indices = generate_outputs.beam_indices[:, :weight_length]
|
| 208 |
+
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
|
| 209 |
+
|
| 210 |
+
# Select the cross attention from the right beam for each output sequences
|
| 211 |
+
weights = torch.stack(
|
| 212 |
+
[
|
| 213 |
+
torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
|
| 214 |
+
for i in range(beam_indices.shape[1])
|
| 215 |
+
],
|
| 216 |
+
dim=2,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# make sure timestamps are as long as weights
|
| 220 |
+
input_length = weight_length or cross_attentions[0].shape[2]
|
| 221 |
+
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
|
| 222 |
+
batch_size = timestamps.shape[0]
|
| 223 |
+
|
| 224 |
+
if num_frames is not None:
|
| 225 |
+
# two cases:
|
| 226 |
+
# 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
|
| 227 |
+
# 2. num_frames is different, compute the DTW matrix for each sample sequentially
|
| 228 |
+
|
| 229 |
+
# we're using np.unique because num_frames can be int/list/tuple
|
| 230 |
+
if isinstance(num_frames, int):
|
| 231 |
+
weights = weights[..., : num_frames // 2]
|
| 232 |
+
|
| 233 |
+
elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
|
| 234 |
+
weights = weights[..., : num_frames[0] // 2]
|
| 235 |
+
|
| 236 |
+
elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
|
| 237 |
+
weights = weights[..., : num_frames[0] // 2]
|
| 238 |
+
|
| 239 |
+
else:
|
| 240 |
+
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
|
| 241 |
+
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
|
| 242 |
+
num_frames = np.repeat(num_frames, repeat_time)
|
| 243 |
+
|
| 244 |
+
if num_frames is None or isinstance(num_frames, int):
|
| 245 |
+
# Normalize and smoothen the weights.
|
| 246 |
+
std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
|
| 247 |
+
mean = torch.mean(weights, dim=-2, keepdim=True)
|
| 248 |
+
weights = (weights - mean) / std
|
| 249 |
+
weights = _median_filter(weights, self.config.median_filter_width)
|
| 250 |
+
|
| 251 |
+
# Average the different cross-attention heads.
|
| 252 |
+
weights = weights.mean(dim=1)
|
| 253 |
+
|
| 254 |
+
# Perform dynamic time warping on each element of the batch.
|
| 255 |
+
for batch_idx in range(batch_size):
|
| 256 |
+
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
|
| 257 |
+
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
|
| 258 |
+
|
| 259 |
+
# Normalize and smoothen the weights.
|
| 260 |
+
std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
|
| 261 |
+
mean = torch.mean(matrix, dim=-2, keepdim=True)
|
| 262 |
+
matrix = (matrix - mean) / std
|
| 263 |
+
matrix = _median_filter(matrix, self.config.median_filter_width)
|
| 264 |
+
|
| 265 |
+
# Average the different cross-attention heads.
|
| 266 |
+
matrix = matrix.mean(dim=0)
|
| 267 |
+
else:
|
| 268 |
+
matrix = weights[batch_idx]
|
| 269 |
+
|
| 270 |
+
text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
|
| 271 |
+
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
| 272 |
+
jump_times = time_indices[jumps] * time_precision
|
| 273 |
+
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
|
| 274 |
+
|
| 275 |
+
return timestamps
|
| 276 |
+
|
| 277 |
+
def generate(
|
| 278 |
+
self,
|
| 279 |
+
input_features: Optional[torch.Tensor] = None,
|
| 280 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 281 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 282 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 283 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
| 284 |
+
synced_gpus: bool = False,
|
| 285 |
+
return_timestamps: Optional[bool] = None,
|
| 286 |
+
task: Optional[str] = None,
|
| 287 |
+
language: Optional[Union[str, List[str]]] = None,
|
| 288 |
+
is_multilingual: Optional[bool] = None,
|
| 289 |
+
prompt_ids: Optional[torch.Tensor] = None,
|
| 290 |
+
prompt_condition_type: Optional[str] = None, # first-segment, all-segments
|
| 291 |
+
condition_on_prev_tokens: Optional[bool] = None,
|
| 292 |
+
temperature: Optional[Union[float, Tuple[float, ...]]] = None,
|
| 293 |
+
compression_ratio_threshold: Optional[float] = None,
|
| 294 |
+
logprob_threshold: Optional[float] = None,
|
| 295 |
+
no_speech_threshold: Optional[float] = None,
|
| 296 |
+
num_segment_frames: Optional[int] = None,
|
| 297 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 298 |
+
time_precision: float = 0.02,
|
| 299 |
+
return_token_timestamps: Optional[bool] = None,
|
| 300 |
+
return_segments: bool = False,
|
| 301 |
+
return_dict_in_generate: Optional[bool] = None,
|
| 302 |
+
**kwargs,
|
| 303 |
+
):
|
| 304 |
+
"""
|
| 305 |
+
Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
|
| 306 |
+
|
| 307 |
+
<Tip warning={true}>
|
| 308 |
+
|
| 309 |
+
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
|
| 310 |
+
model's default generation configuration. You can override any `generation_config` by passing the corresponding
|
| 311 |
+
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
|
| 312 |
+
|
| 313 |
+
For an overview of generation strategies and code examples, check out the [following
|
| 314 |
+
guide](./generation_strategies).
|
| 315 |
+
|
| 316 |
+
</Tip>
|
| 317 |
+
|
| 318 |
+
Parameters:
|
| 319 |
+
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
|
| 320 |
+
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
|
| 321 |
+
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
|
| 322 |
+
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
|
| 323 |
+
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
|
| 324 |
+
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
|
| 325 |
+
generation_config (`~generation.GenerationConfig`, *optional*):
|
| 326 |
+
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
| 327 |
+
passed to generate matching the attributes of `generation_config` will override them. If
|
| 328 |
+
`generation_config` is not provided, the default will be used, which had the following loading
|
| 329 |
+
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
| 330 |
+
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
| 331 |
+
default values, whose documentation should be checked to parameterize generation.
|
| 332 |
+
logits_processor (`LogitsProcessorList`, *optional*):
|
| 333 |
+
Custom logits processors that complement the default logits processors built from arguments and
|
| 334 |
+
generation config. If a logit processor is passed that is already created with the arguments or a
|
| 335 |
+
generation config an error is thrown. This feature is intended for advanced users.
|
| 336 |
+
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
| 337 |
+
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
| 338 |
+
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
| 339 |
+
generation config an error is thrown. This feature is intended for advanced users.
|
| 340 |
+
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
|
| 341 |
+
If provided, this function constraints the beam search to allowed tokens only at each step. If not
|
| 342 |
+
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
|
| 343 |
+
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
|
| 344 |
+
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
|
| 345 |
+
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
| 346 |
+
Retrieval](https://arxiv.org/abs/2010.00904).
|
| 347 |
+
synced_gpus (`bool`, *optional*, defaults to `False`):
|
| 348 |
+
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
| 349 |
+
return_timestamps (`bool`, *optional*):
|
| 350 |
+
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
|
| 351 |
+
task (`str`, *optional*):
|
| 352 |
+
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
|
| 353 |
+
will be updated accordingly.
|
| 354 |
+
language (`str` or list of `str`, *optional*):
|
| 355 |
+
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
|
| 356 |
+
batched generation, a list of language tokens can be passed. You can find all the possible language
|
| 357 |
+
tokens in the `model.generation_config.lang_to_id` dictionary.
|
| 358 |
+
is_multilingual (`bool`, *optional*):
|
| 359 |
+
Whether or not the model is multilingual.
|
| 360 |
+
prompt_ids (`torch.Tensor`, *optional*):
|
| 361 |
+
Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
|
| 362 |
+
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
|
| 363 |
+
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
|
| 364 |
+
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
|
| 365 |
+
prompt_condition_type (`str`, *optional*):
|
| 366 |
+
Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
|
| 367 |
+
Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible.
|
| 368 |
+
condition_on_prev_tokens (`bool`, *optional*):
|
| 369 |
+
Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
|
| 370 |
+
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
|
| 371 |
+
performance.
|
| 372 |
+
temperature (`float` or list of `float`, *optional*):
|
| 373 |
+
The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates
|
| 374 |
+
generation using sampling. For long-form transcription, temperature fallback can be activated by passing
|
| 375 |
+
a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
|
| 376 |
+
performance.
|
| 377 |
+
compression_ratio_threshold (`float`, *optional*):
|
| 378 |
+
Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of
|
| 379 |
+
a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
|
| 380 |
+
repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates
|
| 381 |
+
suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined
|
| 382 |
+
make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35.
|
| 383 |
+
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
|
| 384 |
+
performance.
|
| 385 |
+
logprob_threshold (`float`, *optional*):
|
| 386 |
+
Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of
|
| 387 |
+
a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
|
| 388 |
+
repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability
|
| 389 |
+
can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined
|
| 390 |
+
make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0.
|
| 391 |
+
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
|
| 392 |
+
performance.
|
| 393 |
+
no_speech_threshold (`float`, *optional*):
|
| 394 |
+
Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold`
|
| 395 |
+
is used to determine whether a segment contains only silence. In this case, the transcription for this segment
|
| 396 |
+
is skipped.
|
| 397 |
+
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
|
| 398 |
+
performance.
|
| 399 |
+
num_segment_frames (`int`, *optional*):
|
| 400 |
+
The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride
|
| 401 |
+
times the maximum input length.
|
| 402 |
+
attention_mask (`torch.Tensor`, *optional*):
|
| 403 |
+
`attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
|
| 404 |
+
time_precision (`int`, *optional*, defaults to 0.02):
|
| 405 |
+
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
|
| 406 |
+
for 20 ms.
|
| 407 |
+
return_token_timestamps (`bool`, *optional*):
|
| 408 |
+
Whether to return token-level timestamps with the text. This can be used with or without the
|
| 409 |
+
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
|
| 410 |
+
words.
|
| 411 |
+
return_segments (`bool`, *optional*, defaults to `False`):
|
| 412 |
+
Whether to additionally return a list of all segments. Note that this option can only be enabled
|
| 413 |
+
when doing long-form transcription.
|
| 414 |
+
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
| 415 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
|
| 416 |
+
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
|
| 417 |
+
`return_segments` is set True. In this case the generation outputs of each segment is added to each
|
| 418 |
+
segment.
|
| 419 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 420 |
+
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
| 421 |
+
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
| 422 |
+
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
|
| 423 |
+
|
| 424 |
+
Return:
|
| 425 |
+
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
| 426 |
+
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
|
| 427 |
+
|
| 428 |
+
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
|
| 429 |
+
|
| 430 |
+
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
|
| 431 |
+
|
| 432 |
+
- [`~generation.GenerateEncoderDecoderOutput`],
|
| 433 |
+
- [`~generation.GenerateBeamEncoderDecoderOutput`]
|
| 434 |
+
|
| 435 |
+
else only the generated output sequence ids are returned.
|
| 436 |
+
|
| 437 |
+
Example:
|
| 438 |
+
|
| 439 |
+
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
|
| 440 |
+
|
| 441 |
+
```python
|
| 442 |
+
>>> import torch
|
| 443 |
+
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
|
| 444 |
+
>>> from datasets import load_dataset, Audio
|
| 445 |
+
|
| 446 |
+
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
| 447 |
+
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
| 448 |
+
>>> model.cuda() # doctest: +IGNORE_RESULT
|
| 449 |
+
|
| 450 |
+
>>> # load audios > 30 seconds
|
| 451 |
+
>>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
|
| 452 |
+
>>> # resample to 16kHz
|
| 453 |
+
>>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
|
| 454 |
+
>>> # take first 8 audios and retrieve array
|
| 455 |
+
>>> audio = ds[:8]["audio"]
|
| 456 |
+
>>> audio = [x["array"] for x in audio]
|
| 457 |
+
|
| 458 |
+
>>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
|
| 459 |
+
>>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
|
| 460 |
+
>>> inputs = inputs.to("cuda", torch.float32)
|
| 461 |
+
|
| 462 |
+
>>> # transcribe audio to ids
|
| 463 |
+
>>> generated_ids = model.generate(**inputs)
|
| 464 |
+
|
| 465 |
+
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 466 |
+
>>> transcription[0]
|
| 467 |
+
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
|
| 468 |
+
```
|
| 469 |
+
|
| 470 |
+
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
|
| 471 |
+
|
| 472 |
+
```python
|
| 473 |
+
>>> import torch
|
| 474 |
+
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
|
| 475 |
+
>>> from datasets import load_dataset
|
| 476 |
+
|
| 477 |
+
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
| 478 |
+
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
| 479 |
+
|
| 480 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
| 481 |
+
|
| 482 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
|
| 483 |
+
>>> input_features = inputs.input_features
|
| 484 |
+
|
| 485 |
+
>>> generated_ids = model.generate(inputs=input_features)
|
| 486 |
+
|
| 487 |
+
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 488 |
+
>>> transcription
|
| 489 |
+
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
| 490 |
+
```
|
| 491 |
+
|
| 492 |
+
"""
|
| 493 |
+
# 0. deprecate old inputs
|
| 494 |
+
if "inputs" in kwargs:
|
| 495 |
+
input_features = kwargs.pop("inputs")
|
| 496 |
+
warnings.warn(
|
| 497 |
+
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
|
| 498 |
+
FutureWarning,
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
# 1. prepare generation config
|
| 502 |
+
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
| 503 |
+
|
| 504 |
+
# 2. set global generate variables
|
| 505 |
+
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
|
| 506 |
+
num_segment_frames = input_stride * self.config.max_source_positions
|
| 507 |
+
batch_size, total_input_frames = self._retrieve_total_input_frames(
|
| 508 |
+
input_features=input_features, input_stride=input_stride, kwargs=kwargs
|
| 509 |
+
)
|
| 510 |
+
is_shortform = total_input_frames <= num_segment_frames
|
| 511 |
+
|
| 512 |
+
# 3. Make sure generation config is correctly set
|
| 513 |
+
# Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
|
| 514 |
+
return_dict_in_generate = self._set_return_outputs(
|
| 515 |
+
return_dict_in_generate=return_dict_in_generate,
|
| 516 |
+
return_token_timestamps=return_token_timestamps,
|
| 517 |
+
logprob_threshold=logprob_threshold,
|
| 518 |
+
generation_config=generation_config,
|
| 519 |
+
)
|
| 520 |
+
timestamp_begin = self._set_return_timestamps(
|
| 521 |
+
return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
|
| 522 |
+
)
|
| 523 |
+
self._set_language_and_task(
|
| 524 |
+
language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
|
| 525 |
+
)
|
| 526 |
+
self._set_num_frames(
|
| 527 |
+
return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
|
| 528 |
+
)
|
| 529 |
+
self._set_thresholds_and_condition(
|
| 530 |
+
generation_config=generation_config,
|
| 531 |
+
logprob_threshold=logprob_threshold,
|
| 532 |
+
compression_ratio_threshold=compression_ratio_threshold,
|
| 533 |
+
no_speech_threshold=no_speech_threshold,
|
| 534 |
+
condition_on_prev_tokens=condition_on_prev_tokens,
|
| 535 |
+
)
|
| 536 |
+
self._set_prompt_condition_type(
|
| 537 |
+
generation_config=generation_config,
|
| 538 |
+
prompt_condition_type=prompt_condition_type,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
kwargs["attention_mask"] = attention_mask
|
| 542 |
+
# pass self.config for backward compatibility
|
| 543 |
+
init_tokens = self._retrieve_init_tokens(
|
| 544 |
+
input_features,
|
| 545 |
+
batch_size=batch_size,
|
| 546 |
+
generation_config=generation_config,
|
| 547 |
+
config=self.config,
|
| 548 |
+
num_segment_frames=num_segment_frames,
|
| 549 |
+
kwargs=kwargs,
|
| 550 |
+
)
|
| 551 |
+
# passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
|
| 552 |
+
# where the input ids are handled explicitly by the generate method
|
| 553 |
+
self._check_decoder_input_ids(kwargs=kwargs)
|
| 554 |
+
|
| 555 |
+
# 3. Retrieve logits processors
|
| 556 |
+
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
|
| 557 |
+
begin_index = init_tokens.shape[1]
|
| 558 |
+
logits_processor = self._retrieve_logit_processors(
|
| 559 |
+
generation_config=generation_config,
|
| 560 |
+
logits_processor=logits_processor,
|
| 561 |
+
begin_index=begin_index, # begin index is index of first generated decoder token
|
| 562 |
+
num_beams=kwargs.get("num_beams", 1),
|
| 563 |
+
device=device,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
# 4 Set and retrieve global generation variables
|
| 567 |
+
self._set_condition_on_prev_tokens(
|
| 568 |
+
condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
|
| 572 |
+
temperature = temperatures[0]
|
| 573 |
+
|
| 574 |
+
max_frames, seek = self._retrieve_max_frames_and_seek(
|
| 575 |
+
batch_size=batch_size,
|
| 576 |
+
attention_mask=attention_mask,
|
| 577 |
+
total_input_frames=total_input_frames,
|
| 578 |
+
is_shortform=is_shortform,
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
# 5 Prepare running variables, list for generation
|
| 582 |
+
num_return_sequences = generation_config.num_return_sequences
|
| 583 |
+
(
|
| 584 |
+
batch_idx_map,
|
| 585 |
+
cur_bsz,
|
| 586 |
+
input_features,
|
| 587 |
+
seek,
|
| 588 |
+
max_frames,
|
| 589 |
+
init_tokens,
|
| 590 |
+
do_condition_on_prev_tokens,
|
| 591 |
+
) = self._expand_variables_for_generation(
|
| 592 |
+
input_features=input_features,
|
| 593 |
+
seek=seek,
|
| 594 |
+
max_frames=max_frames,
|
| 595 |
+
init_tokens=init_tokens,
|
| 596 |
+
batch_size=batch_size,
|
| 597 |
+
condition_on_prev_tokens=condition_on_prev_tokens,
|
| 598 |
+
generation_config=generation_config,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
current_segments = self._prepare_segments(
|
| 602 |
+
prompt_ids=prompt_ids,
|
| 603 |
+
batch_size=cur_bsz,
|
| 604 |
+
generation_config=generation_config,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
# 6 Transcribe audio until we reach the end of all input audios
|
| 608 |
+
while (seek < max_frames).any():
|
| 609 |
+
# 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
|
| 610 |
+
# in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
|
| 611 |
+
# to know which original audio is being decoded
|
| 612 |
+
# Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
|
| 613 |
+
input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
|
| 614 |
+
input_features=input_features,
|
| 615 |
+
seek=seek,
|
| 616 |
+
max_frames=max_frames,
|
| 617 |
+
cur_bsz=cur_bsz,
|
| 618 |
+
batch_idx_map=batch_idx_map,
|
| 619 |
+
)
|
| 620 |
+
time_offset = seek * time_precision / input_stride
|
| 621 |
+
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
|
| 622 |
+
|
| 623 |
+
# 6.2 cut out next 30s segment from input features
|
| 624 |
+
segment_input = self._get_input_segment(
|
| 625 |
+
input_features=input_features,
|
| 626 |
+
seek=seek,
|
| 627 |
+
seek_num_frames=seek_num_frames,
|
| 628 |
+
num_segment_frames=num_segment_frames,
|
| 629 |
+
cur_bsz=cur_bsz,
|
| 630 |
+
batch_idx_map=batch_idx_map,
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
# 6.3 prepare decoder input ids
|
| 634 |
+
suppress_tokens = _get_attr_from_logit_processors(
|
| 635 |
+
logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
|
| 639 |
+
cur_bsz=cur_bsz,
|
| 640 |
+
init_tokens=init_tokens,
|
| 641 |
+
current_segments=current_segments,
|
| 642 |
+
batch_idx_map=batch_idx_map,
|
| 643 |
+
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
|
| 644 |
+
prompt_ids=prompt_ids,
|
| 645 |
+
generation_config=generation_config,
|
| 646 |
+
config=self.config,
|
| 647 |
+
device=init_tokens.device,
|
| 648 |
+
suppress_tokens=suppress_tokens,
|
| 649 |
+
kwargs=kwargs,
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
# 6.4 set max new tokens or max length
|
| 653 |
+
self._set_max_new_tokens_and_length(
|
| 654 |
+
config=self.config,
|
| 655 |
+
decoder_input_ids=decoder_input_ids,
|
| 656 |
+
generation_config=generation_config,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
# 6.5 Set current `begin_index` for all logit processors
|
| 660 |
+
if logits_processor is not None:
|
| 661 |
+
for proc in logits_processor:
|
| 662 |
+
if hasattr(proc, "set_begin_index"):
|
| 663 |
+
proc.set_begin_index(decoder_input_ids.shape[-1])
|
| 664 |
+
|
| 665 |
+
# 6.6 Run generate with fallback
|
| 666 |
+
(
|
| 667 |
+
seek_sequences,
|
| 668 |
+
seek_outputs,
|
| 669 |
+
should_skip,
|
| 670 |
+
do_condition_on_prev_tokens,
|
| 671 |
+
model_output_type,
|
| 672 |
+
) = self.generate_with_fallback(
|
| 673 |
+
segment_input=segment_input,
|
| 674 |
+
decoder_input_ids=decoder_input_ids,
|
| 675 |
+
cur_bsz=cur_bsz,
|
| 676 |
+
batch_idx_map=batch_idx_map,
|
| 677 |
+
seek=seek,
|
| 678 |
+
num_segment_frames=num_segment_frames,
|
| 679 |
+
max_frames=max_frames,
|
| 680 |
+
temperatures=temperatures,
|
| 681 |
+
generation_config=generation_config,
|
| 682 |
+
logits_processor=logits_processor,
|
| 683 |
+
stopping_criteria=stopping_criteria,
|
| 684 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 685 |
+
synced_gpus=synced_gpus,
|
| 686 |
+
return_token_timestamps=return_token_timestamps,
|
| 687 |
+
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
|
| 688 |
+
is_shortform=is_shortform,
|
| 689 |
+
batch_size=batch_size,
|
| 690 |
+
kwargs=kwargs,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
# 6.7 In every generated sequence, split by timestamp tokens and extract segments
|
| 694 |
+
for i, seek_sequence in enumerate(seek_sequences):
|
| 695 |
+
prev_i = batch_idx_map[i]
|
| 696 |
+
|
| 697 |
+
if should_skip[i]:
|
| 698 |
+
seek[prev_i] += seek_num_frames[prev_i]
|
| 699 |
+
continue
|
| 700 |
+
|
| 701 |
+
segments, segment_offset = self._retrieve_segment(
|
| 702 |
+
seek_sequence=seek_sequence,
|
| 703 |
+
seek_outputs=seek_outputs,
|
| 704 |
+
time_offset=time_offset,
|
| 705 |
+
timestamp_begin=timestamp_begin,
|
| 706 |
+
seek_num_frames=seek_num_frames,
|
| 707 |
+
time_precision=time_precision,
|
| 708 |
+
input_stride=input_stride,
|
| 709 |
+
prev_idx=prev_i,
|
| 710 |
+
idx=i,
|
| 711 |
+
return_token_timestamps=return_token_timestamps,
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
current_segments[prev_i] += segments
|
| 715 |
+
|
| 716 |
+
if is_shortform:
|
| 717 |
+
seek[prev_i] += max_frames[i]
|
| 718 |
+
else:
|
| 719 |
+
seek[prev_i] += segment_offset
|
| 720 |
+
|
| 721 |
+
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
|
| 722 |
+
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
|
| 723 |
+
final_segments = (
|
| 724 |
+
[x[1:] for x in current_segments]
|
| 725 |
+
if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
|
| 726 |
+
else current_segments
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
sequences = _pad_to_max_length(
|
| 730 |
+
final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
|
| 734 |
+
if return_segments:
|
| 735 |
+
return {"sequences": sequences, "segments": final_segments}
|
| 736 |
+
|
| 737 |
+
if is_shortform:
|
| 738 |
+
# add eos token:
|
| 739 |
+
if generation_config.max_new_tokens is None and generation_config.max_length is None:
|
| 740 |
+
eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
|
| 741 |
+
sequences = torch.cat([sequences, eos_tokens], dim=-1)
|
| 742 |
+
|
| 743 |
+
if return_token_timestamps:
|
| 744 |
+
outputs = {}
|
| 745 |
+
outputs["sequences"] = sequences
|
| 746 |
+
outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
|
| 747 |
+
else:
|
| 748 |
+
outputs = sequences
|
| 749 |
+
|
| 750 |
+
if return_dict_in_generate and generation_config.return_dict_in_generate:
|
| 751 |
+
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
|
| 752 |
+
|
| 753 |
+
if num_return_sequences > 1:
|
| 754 |
+
if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
|
| 755 |
+
dict_outputs.encoder_attentions = tuple(
|
| 756 |
+
dict_outputs.encoder_attentions[i][::num_return_sequences]
|
| 757 |
+
for i in range(len(dict_outputs.encoder_attentions))
|
| 758 |
+
)
|
| 759 |
+
if (
|
| 760 |
+
hasattr(dict_outputs, "encoder_hidden_states")
|
| 761 |
+
and dict_outputs.encoder_hidden_states is not None
|
| 762 |
+
):
|
| 763 |
+
dict_outputs.encoder_hidden_states = tuple(
|
| 764 |
+
dict_outputs.encoder_hidden_states[i][::num_return_sequences]
|
| 765 |
+
for i in range(len(dict_outputs.encoder_hidden_states))
|
| 766 |
+
)
|
| 767 |
+
if return_token_timestamps:
|
| 768 |
+
dict_outputs["token_timestamps"] = outputs["token_timestamps"]
|
| 769 |
+
return dict_outputs
|
| 770 |
+
|
| 771 |
+
return outputs
|
| 772 |
+
|
| 773 |
+
return sequences
|
| 774 |
+
|
| 775 |
+
def generate_with_fallback(
|
| 776 |
+
self,
|
| 777 |
+
segment_input,
|
| 778 |
+
decoder_input_ids,
|
| 779 |
+
cur_bsz,
|
| 780 |
+
batch_idx_map,
|
| 781 |
+
seek,
|
| 782 |
+
num_segment_frames,
|
| 783 |
+
max_frames,
|
| 784 |
+
temperatures,
|
| 785 |
+
generation_config,
|
| 786 |
+
logits_processor,
|
| 787 |
+
stopping_criteria,
|
| 788 |
+
prefix_allowed_tokens_fn,
|
| 789 |
+
synced_gpus,
|
| 790 |
+
return_token_timestamps,
|
| 791 |
+
do_condition_on_prev_tokens,
|
| 792 |
+
is_shortform,
|
| 793 |
+
batch_size,
|
| 794 |
+
kwargs,
|
| 795 |
+
):
|
| 796 |
+
kwargs = copy.copy(kwargs)
|
| 797 |
+
|
| 798 |
+
# 6.6 Batch generate current chunk
|
| 799 |
+
seek_sequence_list = [None for _ in range(cur_bsz)]
|
| 800 |
+
seek_outputs_list = [None for _ in range(cur_bsz)]
|
| 801 |
+
needs_fallback = [False for _ in range(cur_bsz)]
|
| 802 |
+
should_skip = [False for _ in range(cur_bsz)]
|
| 803 |
+
fallback_index_map = list(range(cur_bsz))
|
| 804 |
+
if generation_config.no_speech_threshold is not None:
|
| 805 |
+
self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs)
|
| 806 |
+
|
| 807 |
+
for fallback_idx, temperature in enumerate(temperatures):
|
| 808 |
+
generation_config.do_sample = temperature is not None and temperature > 0.0
|
| 809 |
+
generation_config.temperature = temperature if generation_config.do_sample else 1.0
|
| 810 |
+
if generation_config.do_sample:
|
| 811 |
+
generation_config.num_beams = 1
|
| 812 |
+
|
| 813 |
+
generate_kwargs = copy.copy(kwargs)
|
| 814 |
+
for key in ["do_sample", "temperature", "num_beams"]:
|
| 815 |
+
if key in generate_kwargs:
|
| 816 |
+
del generate_kwargs[key]
|
| 817 |
+
|
| 818 |
+
cur_bsz = decoder_input_ids.shape[0]
|
| 819 |
+
if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
|
| 820 |
+
segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
|
| 821 |
+
decoder_input_ids = F.pad(
|
| 822 |
+
decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
|
| 823 |
+
)
|
| 824 |
+
if generate_kwargs.get("decoder_attention_mask") is not None:
|
| 825 |
+
generate_kwargs["decoder_attention_mask"] = F.pad(
|
| 826 |
+
generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
|
| 827 |
+
)
|
| 828 |
+
if generate_kwargs.get("encoder_outputs") is not None:
|
| 829 |
+
generate_kwargs["encoder_outputs"] = F.pad(
|
| 830 |
+
generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
seek_outputs = super().generate(
|
| 834 |
+
segment_input,
|
| 835 |
+
generation_config=generation_config,
|
| 836 |
+
logits_processor=logits_processor,
|
| 837 |
+
stopping_criteria=stopping_criteria,
|
| 838 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 839 |
+
synced_gpus=synced_gpus,
|
| 840 |
+
decoder_input_ids=decoder_input_ids,
|
| 841 |
+
**generate_kwargs,
|
| 842 |
+
)
|
| 843 |
+
|
| 844 |
+
model_output_type = type(seek_outputs)
|
| 845 |
+
|
| 846 |
+
# post-process sequence tokens and outputs to be in list form
|
| 847 |
+
seek_sequences, seek_outputs = self._postprocess_outputs(
|
| 848 |
+
seek_outputs=seek_outputs,
|
| 849 |
+
decoder_input_ids=decoder_input_ids,
|
| 850 |
+
return_token_timestamps=return_token_timestamps,
|
| 851 |
+
generation_config=generation_config,
|
| 852 |
+
is_shortform=is_shortform,
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
if cur_bsz < batch_size:
|
| 856 |
+
seek_sequences = seek_sequences[:cur_bsz]
|
| 857 |
+
seek_outputs = seek_outputs[:cur_bsz]
|
| 858 |
+
|
| 859 |
+
# 6.7 Extract cut sequences from every sequence and check if fallback should be applied
|
| 860 |
+
# Loop over each decoded audio individually as each decoding can be of a different length
|
| 861 |
+
new_fallback_index_map = []
|
| 862 |
+
new_segment_input = []
|
| 863 |
+
new_decoder_input_ids = []
|
| 864 |
+
new_decoder_attention_mask = []
|
| 865 |
+
|
| 866 |
+
for i, seek_sequence in enumerate(seek_sequences):
|
| 867 |
+
# make sure we cut a predicted EOS token if we are not finished with the generation yet
|
| 868 |
+
prev_i = batch_idx_map[fallback_index_map[i]]
|
| 869 |
+
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
|
| 870 |
+
|
| 871 |
+
# remove eos token id
|
| 872 |
+
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
|
| 873 |
+
seek_sequence = seek_sequence[:-1]
|
| 874 |
+
if return_token_timestamps and not is_shortform:
|
| 875 |
+
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
|
| 876 |
+
|
| 877 |
+
# remove all padding tokens
|
| 878 |
+
if seek_sequence[-1] == generation_config.pad_token_id:
|
| 879 |
+
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
|
| 880 |
+
seek_sequence = seek_sequence[:-num_paddings]
|
| 881 |
+
if return_token_timestamps and not is_shortform:
|
| 882 |
+
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
|
| 883 |
+
|
| 884 |
+
# check which sequences in batch need fallback & which should be skipped
|
| 885 |
+
needs_fallback[i], should_skip[i] = self._need_fallback(
|
| 886 |
+
seek_sequence,
|
| 887 |
+
seek_outputs,
|
| 888 |
+
i,
|
| 889 |
+
logits_processor,
|
| 890 |
+
generation_config,
|
| 891 |
+
self.config.vocab_size,
|
| 892 |
+
temperature,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
seek_sequence_list[fallback_index_map[i]] = seek_sequence
|
| 896 |
+
seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
|
| 897 |
+
is_low_temperature = temperature is None or temperature < 0.5
|
| 898 |
+
do_condition_on_prev_tokens[fallback_index_map[i]] = (
|
| 899 |
+
generation_config.condition_on_prev_tokens and is_low_temperature
|
| 900 |
+
)
|
| 901 |
+
|
| 902 |
+
if needs_fallback[i]:
|
| 903 |
+
new_fallback_index_map.append(fallback_index_map[i])
|
| 904 |
+
new_segment_input.append(segment_input[i])
|
| 905 |
+
new_decoder_input_ids.append(decoder_input_ids[i])
|
| 906 |
+
if "decoder_attention_mask" in kwargs:
|
| 907 |
+
new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i])
|
| 908 |
+
|
| 909 |
+
fallback_index_map = new_fallback_index_map
|
| 910 |
+
|
| 911 |
+
# if no sequence needs to be run with temperature fallback, we're finished
|
| 912 |
+
if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
|
| 913 |
+
seek_sequences = seek_sequence_list
|
| 914 |
+
seek_outputs = seek_outputs_list
|
| 915 |
+
break
|
| 916 |
+
|
| 917 |
+
# if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
|
| 918 |
+
decoder_input_ids = torch.stack(new_decoder_input_ids)
|
| 919 |
+
segment_input = torch.stack(new_segment_input)
|
| 920 |
+
if "decoder_attention_mask" in kwargs:
|
| 921 |
+
kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask)
|
| 922 |
+
|
| 923 |
+
return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
|
| 924 |
+
|
| 925 |
+
@staticmethod
|
| 926 |
+
def _prepare_segments(prompt_ids, batch_size, generation_config):
|
| 927 |
+
if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
|
| 928 |
+
prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
|
| 929 |
+
prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
|
| 930 |
+
current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
|
| 931 |
+
else:
|
| 932 |
+
current_segments = [[] for _ in range(batch_size)]
|
| 933 |
+
|
| 934 |
+
return current_segments
|
| 935 |
+
|
| 936 |
+
def _postprocess_outputs(
|
| 937 |
+
self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
|
| 938 |
+
):
|
| 939 |
+
# remove all previously passed decoder input ids
|
| 940 |
+
start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
|
| 941 |
+
|
| 942 |
+
if isinstance(seek_outputs, torch.Tensor):
|
| 943 |
+
seek_outputs = seek_outputs[:, start_idx:]
|
| 944 |
+
return seek_outputs, seek_outputs
|
| 945 |
+
|
| 946 |
+
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
| 947 |
+
num_frames = getattr(generation_config, "num_frames", None)
|
| 948 |
+
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
| 949 |
+
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
|
| 950 |
+
)
|
| 951 |
+
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
|
| 952 |
+
|
| 953 |
+
seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
|
| 954 |
+
|
| 955 |
+
def split_by_batch_index(values, key, batch_idx, is_shortform):
|
| 956 |
+
if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
|
| 957 |
+
return [v[batch_idx].cpu() for v in values]
|
| 958 |
+
if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
|
| 959 |
+
return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
|
| 960 |
+
elif key == "past_key_values":
|
| 961 |
+
if not is_shortform:
|
| 962 |
+
# we don't save `past_key_values` as this is too costly for longform
|
| 963 |
+
return None
|
| 964 |
+
elif isinstance(values, EncoderDecoderCache):
|
| 965 |
+
all_past_key_values = []
|
| 966 |
+
for layer_idx in range(self.config.decoder_layers):
|
| 967 |
+
layer_past_key_values = []
|
| 968 |
+
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
|
| 969 |
+
for v in [cache_cls.key_cache, cache_cls.value_cache]:
|
| 970 |
+
layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
|
| 971 |
+
all_past_key_values.append(tuple(layer_past_key_values))
|
| 972 |
+
return tuple(all_past_key_values)
|
| 973 |
+
else:
|
| 974 |
+
all_past_key_values = []
|
| 975 |
+
for v in range(len(values)):
|
| 976 |
+
layer_past_key_values = []
|
| 977 |
+
for w in values[v]:
|
| 978 |
+
layer_past_key_values.append(w[batch_idx][None].cpu())
|
| 979 |
+
all_past_key_values.append(tuple(layer_past_key_values))
|
| 980 |
+
return tuple(all_past_key_values)
|
| 981 |
+
|
| 982 |
+
return values[batch_idx].cpu()
|
| 983 |
+
|
| 984 |
+
sequence_tokens = seek_outputs["sequences"]
|
| 985 |
+
seek_outputs = [
|
| 986 |
+
{k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
|
| 987 |
+
for i in range(sequence_tokens.shape[0])
|
| 988 |
+
]
|
| 989 |
+
|
| 990 |
+
return sequence_tokens, seek_outputs
|
| 991 |
+
|
| 992 |
+
def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
|
| 993 |
+
# Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
|
| 994 |
+
outputs = {}
|
| 995 |
+
for key in seek_outputs[0].keys():
|
| 996 |
+
if key == "sequences":
|
| 997 |
+
outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
|
| 998 |
+
if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
|
| 999 |
+
outputs[key] = tuple(
|
| 1000 |
+
torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
|
| 1001 |
+
)
|
| 1002 |
+
if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
|
| 1003 |
+
outputs[key] = tuple(
|
| 1004 |
+
tuple(
|
| 1005 |
+
torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
|
| 1006 |
+
for j in range(len(seek_outputs[0][key][0]))
|
| 1007 |
+
)
|
| 1008 |
+
for i in range(len(seek_outputs[0][key]))
|
| 1009 |
+
)
|
| 1010 |
+
if key == "past_key_values":
|
| 1011 |
+
past_key_value_type = kwargs.get("past_key_values")
|
| 1012 |
+
if seek_outputs[0][key] is not None:
|
| 1013 |
+
outputs[key] = tuple(
|
| 1014 |
+
tuple(
|
| 1015 |
+
torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
|
| 1016 |
+
for j in range(len(seek_outputs[0][key][0]))
|
| 1017 |
+
)
|
| 1018 |
+
for i in range(len(seek_outputs[0][key]))
|
| 1019 |
+
)
|
| 1020 |
+
if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache):
|
| 1021 |
+
outputs[key] = past_key_value_type.from_legacy_cache(outputs[key])
|
| 1022 |
+
else:
|
| 1023 |
+
outputs[key] = None
|
| 1024 |
+
|
| 1025 |
+
return model_output_type(**outputs)
|
| 1026 |
+
|
| 1027 |
+
def _need_fallback(
|
| 1028 |
+
self,
|
| 1029 |
+
seek_sequence,
|
| 1030 |
+
seek_outputs,
|
| 1031 |
+
index,
|
| 1032 |
+
logits_processor,
|
| 1033 |
+
generation_config,
|
| 1034 |
+
vocab_size,
|
| 1035 |
+
temperature,
|
| 1036 |
+
):
|
| 1037 |
+
needs_fallback = False
|
| 1038 |
+
should_skip = False
|
| 1039 |
+
if generation_config.compression_ratio_threshold is not None:
|
| 1040 |
+
compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size)
|
| 1041 |
+
|
| 1042 |
+
if compression_ratio > generation_config.compression_ratio_threshold:
|
| 1043 |
+
needs_fallback = True
|
| 1044 |
+
|
| 1045 |
+
if generation_config.logprob_threshold is not None:
|
| 1046 |
+
if hasattr(seek_outputs[0], "sequences_scores"):
|
| 1047 |
+
logprobs = [s["sequences_scores"] for s in seek_outputs][index]
|
| 1048 |
+
else:
|
| 1049 |
+
scores = seek_outputs[index]["scores"]
|
| 1050 |
+
logprobs = self._retrieve_avg_logprobs(
|
| 1051 |
+
scores, seek_sequence, generation_config.eos_token_id, temperature
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
if logprobs < generation_config.logprob_threshold:
|
| 1055 |
+
needs_fallback = True
|
| 1056 |
+
|
| 1057 |
+
if generation_config.no_speech_threshold is not None:
|
| 1058 |
+
no_speech_prob = _get_attr_from_logit_processors(
|
| 1059 |
+
logits_processor, WhisperNoSpeechDetection, "no_speech_prob"
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
if (
|
| 1063 |
+
logprobs < generation_config.logprob_threshold
|
| 1064 |
+
and no_speech_prob[index] > generation_config.no_speech_threshold
|
| 1065 |
+
):
|
| 1066 |
+
needs_fallback = False
|
| 1067 |
+
should_skip = True
|
| 1068 |
+
|
| 1069 |
+
return needs_fallback, should_skip
|
| 1070 |
+
|
| 1071 |
+
def _expand_variables_for_generation(
|
| 1072 |
+
self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config
|
| 1073 |
+
):
|
| 1074 |
+
if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1:
|
| 1075 |
+
batch_idx_map = list(range(batch_size * generation_config.num_return_sequences))
|
| 1076 |
+
cur_bsz = len(batch_idx_map)
|
| 1077 |
+
do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))]
|
| 1078 |
+
input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0)
|
| 1079 |
+
seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0)
|
| 1080 |
+
max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0)
|
| 1081 |
+
init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0)
|
| 1082 |
+
generation_config.num_return_sequences = 1
|
| 1083 |
+
else:
|
| 1084 |
+
cur_bsz = batch_size
|
| 1085 |
+
batch_idx_map = list(range(cur_bsz))
|
| 1086 |
+
do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)]
|
| 1087 |
+
|
| 1088 |
+
return (
|
| 1089 |
+
batch_idx_map,
|
| 1090 |
+
cur_bsz,
|
| 1091 |
+
input_features,
|
| 1092 |
+
seek,
|
| 1093 |
+
max_frames,
|
| 1094 |
+
init_tokens,
|
| 1095 |
+
do_condition_on_prev_tokens,
|
| 1096 |
+
)
|
| 1097 |
+
|
| 1098 |
+
@staticmethod
|
| 1099 |
+
def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
|
| 1100 |
+
set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
|
| 1101 |
+
extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
|
| 1102 |
+
set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
|
| 1103 |
+
|
| 1104 |
+
@staticmethod
|
| 1105 |
+
def _retrieve_total_input_frames(input_features, input_stride, kwargs):
|
| 1106 |
+
if input_features is not None:
|
| 1107 |
+
return input_features.shape[0], input_features.shape[-1]
|
| 1108 |
+
|
| 1109 |
+
if "encoder_outputs" in kwargs:
|
| 1110 |
+
encoder_outputs_shape = (
|
| 1111 |
+
kwargs["encoder_outputs"][0].shape
|
| 1112 |
+
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
|
| 1113 |
+
else kwargs["encoder_outputs"].shape
|
| 1114 |
+
)
|
| 1115 |
+
return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
|
| 1116 |
+
|
| 1117 |
+
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
|
| 1118 |
+
|
| 1119 |
+
@staticmethod
|
| 1120 |
+
def _maybe_warn_unused_inputs(
|
| 1121 |
+
condition_on_prev_tokens,
|
| 1122 |
+
temperature,
|
| 1123 |
+
compression_ratio_threshold,
|
| 1124 |
+
logprob_threshold,
|
| 1125 |
+
no_speech_threshold,
|
| 1126 |
+
total_input_frames,
|
| 1127 |
+
):
|
| 1128 |
+
warning_prefix = (
|
| 1129 |
+
f"Audio input consists of only {total_input_frames}. "
|
| 1130 |
+
"Short-form transcription is activated."
|
| 1131 |
+
"{}, but will be ignored."
|
| 1132 |
+
)
|
| 1133 |
+
if condition_on_prev_tokens is not None:
|
| 1134 |
+
logger.warning(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}"))
|
| 1135 |
+
|
| 1136 |
+
if compression_ratio_threshold is not None:
|
| 1137 |
+
logger.warning(
|
| 1138 |
+
warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}")
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
if logprob_threshold is not None:
|
| 1142 |
+
logger.warning(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}"))
|
| 1143 |
+
|
| 1144 |
+
if no_speech_threshold is not None:
|
| 1145 |
+
logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
|
| 1146 |
+
|
| 1147 |
+
# when passing temperature as a list it cannot just be ignored => throw error in this case
|
| 1148 |
+
if isinstance(temperature, (list, tuple)):
|
| 1149 |
+
raise ValueError(
|
| 1150 |
+
f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
|
| 1151 |
+
f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
|
| 1152 |
+
)
|
| 1153 |
+
|
| 1154 |
+
@staticmethod
|
| 1155 |
+
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
|
| 1156 |
+
if return_dict_in_generate is None:
|
| 1157 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 1158 |
+
else:
|
| 1159 |
+
generation_config.return_dict_in_generate = return_dict_in_generate
|
| 1160 |
+
|
| 1161 |
+
generation_config.return_token_timestamps = return_token_timestamps
|
| 1162 |
+
if return_token_timestamps:
|
| 1163 |
+
generation_config.return_dict_in_generate = True
|
| 1164 |
+
generation_config.output_attentions = True
|
| 1165 |
+
generation_config.output_scores = True
|
| 1166 |
+
|
| 1167 |
+
if logprob_threshold is not None:
|
| 1168 |
+
generation_config.return_dict_in_generate = True
|
| 1169 |
+
generation_config.output_scores = True
|
| 1170 |
+
|
| 1171 |
+
return return_dict_in_generate
|
| 1172 |
+
|
| 1173 |
+
def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
|
| 1174 |
+
if return_timestamps is None and hasattr(generation_config, "return_timestamps"):
|
| 1175 |
+
return_timestamps = generation_config.return_timestamps
|
| 1176 |
+
|
| 1177 |
+
if not is_shortform:
|
| 1178 |
+
if return_timestamps is False:
|
| 1179 |
+
raise ValueError(
|
| 1180 |
+
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
|
| 1181 |
+
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
|
| 1182 |
+
)
|
| 1183 |
+
|
| 1184 |
+
logger.info("Setting `return_timestamps=True` for long-form generation.")
|
| 1185 |
+
return_timestamps = True
|
| 1186 |
+
|
| 1187 |
+
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
|
| 1188 |
+
raise ValueError(
|
| 1189 |
+
"You are trying to return timestamps, but the generation config is not properly set. "
|
| 1190 |
+
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
|
| 1191 |
+
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
generation_config.return_timestamps = return_timestamps
|
| 1195 |
+
|
| 1196 |
+
if hasattr(generation_config, "no_timestamps_token_id"):
|
| 1197 |
+
timestamp_begin = generation_config.no_timestamps_token_id + 1
|
| 1198 |
+
else:
|
| 1199 |
+
# BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
|
| 1200 |
+
# We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
|
| 1201 |
+
timestamp_begin = self.config.vocab_size + 1
|
| 1202 |
+
|
| 1203 |
+
return timestamp_begin
|
| 1204 |
+
|
| 1205 |
+
@staticmethod
|
| 1206 |
+
def _set_language_and_task(language, task, is_multilingual, generation_config):
|
| 1207 |
+
if is_multilingual is not None:
|
| 1208 |
+
if not hasattr(generation_config, "is_multilingual"):
|
| 1209 |
+
raise ValueError(
|
| 1210 |
+
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
|
| 1211 |
+
"to `generate`. Please update the generation config as per the instructions "
|
| 1212 |
+
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
| 1213 |
+
)
|
| 1214 |
+
generation_config.is_multilingual = is_multilingual
|
| 1215 |
+
|
| 1216 |
+
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
|
| 1217 |
+
if task is not None or language is not None:
|
| 1218 |
+
raise ValueError(
|
| 1219 |
+
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
|
| 1220 |
+
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
if language is not None:
|
| 1224 |
+
if not hasattr(generation_config, "lang_to_id"):
|
| 1225 |
+
raise ValueError(
|
| 1226 |
+
"The generation config is outdated and is thus not compatible with the `language` argument "
|
| 1227 |
+
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
|
| 1228 |
+
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
| 1229 |
+
)
|
| 1230 |
+
generation_config.language = language
|
| 1231 |
+
|
| 1232 |
+
if task is not None:
|
| 1233 |
+
if not hasattr(generation_config, "task_to_id"):
|
| 1234 |
+
raise ValueError(
|
| 1235 |
+
"The generation config is outdated and is thus not compatible with the `task` argument "
|
| 1236 |
+
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
|
| 1237 |
+
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
| 1238 |
+
)
|
| 1239 |
+
generation_config.task = task
|
| 1240 |
+
|
| 1241 |
+
def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
|
| 1242 |
+
def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
|
| 1243 |
+
"""short function to replace num with a itr in lst"""
|
| 1244 |
+
found = any(i in lst for i in itr)
|
| 1245 |
+
if found:
|
| 1246 |
+
lst = [num if i in itr else i for i in lst]
|
| 1247 |
+
else:
|
| 1248 |
+
lst.append(num)
|
| 1249 |
+
return lst
|
| 1250 |
+
|
| 1251 |
+
def language_to_id(language: str) -> int:
|
| 1252 |
+
language = language.lower()
|
| 1253 |
+
if language in generation_config.lang_to_id.keys():
|
| 1254 |
+
language_token = language
|
| 1255 |
+
elif language in TO_LANGUAGE_CODE.keys():
|
| 1256 |
+
language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
|
| 1257 |
+
elif language in TO_LANGUAGE_CODE.values():
|
| 1258 |
+
language_token = f"<|{language}|>"
|
| 1259 |
+
else:
|
| 1260 |
+
is_language_code = len(language) == 2
|
| 1261 |
+
raise ValueError(
|
| 1262 |
+
f"Unsupported language: {language}. Language should be one of:"
|
| 1263 |
+
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
|
| 1264 |
+
)
|
| 1265 |
+
if language_token not in generation_config.lang_to_id:
|
| 1266 |
+
raise ValueError(
|
| 1267 |
+
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
|
| 1268 |
+
"(You should just add it to the generation config)"
|
| 1269 |
+
)
|
| 1270 |
+
|
| 1271 |
+
return generation_config.lang_to_id[language_token]
|
| 1272 |
+
|
| 1273 |
+
task = getattr(generation_config, "task", None)
|
| 1274 |
+
language = getattr(generation_config, "language", None)
|
| 1275 |
+
|
| 1276 |
+
forced_decoder_ids = generation_config.forced_decoder_ids
|
| 1277 |
+
if forced_decoder_ids is not None:
|
| 1278 |
+
if language is None and task is None and forced_decoder_ids[0][1] is None:
|
| 1279 |
+
logger.warning_once(
|
| 1280 |
+
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
|
| 1281 |
+
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
|
| 1282 |
+
)
|
| 1283 |
+
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
|
| 1284 |
+
forced_decoder_ids = config.forced_decoder_ids
|
| 1285 |
+
|
| 1286 |
+
if forced_decoder_ids is not None and task is not None:
|
| 1287 |
+
logger.warning_once(
|
| 1288 |
+
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
|
| 1289 |
+
)
|
| 1290 |
+
forced_decoder_ids = None
|
| 1291 |
+
elif forced_decoder_ids is not None and language is not None:
|
| 1292 |
+
logger.warning_once(
|
| 1293 |
+
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
|
| 1294 |
+
)
|
| 1295 |
+
forced_decoder_ids = None
|
| 1296 |
+
|
| 1297 |
+
init_tokens = [generation_config.decoder_start_token_id]
|
| 1298 |
+
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
|
| 1299 |
+
i = 1
|
| 1300 |
+
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
|
| 1301 |
+
init_tokens += [forced_decoder_ids[0][1]]
|
| 1302 |
+
forced_decoder_ids = forced_decoder_ids[1:]
|
| 1303 |
+
i += 1
|
| 1304 |
+
|
| 1305 |
+
if len(forced_decoder_ids) > 0:
|
| 1306 |
+
raise ValueError(
|
| 1307 |
+
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
|
| 1308 |
+
)
|
| 1309 |
+
|
| 1310 |
+
# from v4.39 the forced decoder ids are always None in favour of decoder input ids
|
| 1311 |
+
generation_config.forced_decoder_ids = None
|
| 1312 |
+
|
| 1313 |
+
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
|
| 1314 |
+
|
| 1315 |
+
# Make sure language is a list of strings of the correct length
|
| 1316 |
+
if isinstance(language, (list, tuple)):
|
| 1317 |
+
if any(l is None for l in language):
|
| 1318 |
+
raise TypeError(
|
| 1319 |
+
"Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
|
| 1320 |
+
)
|
| 1321 |
+
if len(language) != batch_size:
|
| 1322 |
+
raise ValueError(
|
| 1323 |
+
"When passing a list of languages, the length of the list must match the batch size. "
|
| 1324 |
+
f"Expected length of {batch_size}, but got {len(language)} languages."
|
| 1325 |
+
)
|
| 1326 |
+
languages = language
|
| 1327 |
+
elif language is None:
|
| 1328 |
+
# Language will be detected for each item in batch
|
| 1329 |
+
languages = [None] * batch_size
|
| 1330 |
+
else:
|
| 1331 |
+
languages = [language] # Use a length-1 list now, broadcast later
|
| 1332 |
+
|
| 1333 |
+
# Separate init_tokens for each language
|
| 1334 |
+
init_tokens = [copy.copy(init_tokens) for _ in languages]
|
| 1335 |
+
|
| 1336 |
+
# Update init_tokens with languages
|
| 1337 |
+
lang_ids = None
|
| 1338 |
+
if language is not None:
|
| 1339 |
+
lang_ids = [language_to_id(l) for l in languages]
|
| 1340 |
+
elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
|
| 1341 |
+
# language is not defined or intentially set to `None` to trigger language detection
|
| 1342 |
+
lang_ids = self.detect_language(
|
| 1343 |
+
input_features=input_features,
|
| 1344 |
+
encoder_outputs=kwargs.get("encoder_outputs", None),
|
| 1345 |
+
attention_mask=kwargs.get("attention_mask", None),
|
| 1346 |
+
generation_config=generation_config,
|
| 1347 |
+
num_segment_frames=num_segment_frames,
|
| 1348 |
+
).tolist()
|
| 1349 |
+
if lang_ids is not None:
|
| 1350 |
+
# append or replace lang_ids to init_tokens
|
| 1351 |
+
for i in range(len(init_tokens)):
|
| 1352 |
+
if len(init_tokens[i]) > 1:
|
| 1353 |
+
init_tokens[i][1] = lang_ids[i]
|
| 1354 |
+
else:
|
| 1355 |
+
init_tokens[i].append(lang_ids[i])
|
| 1356 |
+
del languages
|
| 1357 |
+
|
| 1358 |
+
# Update init_tokens with task
|
| 1359 |
+
for i in range(len(init_tokens)):
|
| 1360 |
+
if task is not None:
|
| 1361 |
+
if task in TASK_IDS:
|
| 1362 |
+
init_tokens[i].append(generation_config.task_to_id[generation_config.task])
|
| 1363 |
+
task_id = generation_config.task_to_id[generation_config.task]
|
| 1364 |
+
|
| 1365 |
+
# if task is defined it'll overwrite task ids that might have already been defined via the generation_config
|
| 1366 |
+
replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
|
| 1367 |
+
else:
|
| 1368 |
+
raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
|
| 1369 |
+
elif language is not None and hasattr(generation_config, "task_to_id"):
|
| 1370 |
+
# if language is defined, but no task id is in `init_tokens`, default to transcribe
|
| 1371 |
+
if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
|
| 1372 |
+
init_tokens[i].append(generation_config.task_to_id["transcribe"])
|
| 1373 |
+
|
| 1374 |
+
if (
|
| 1375 |
+
not generation_config.return_timestamps
|
| 1376 |
+
and hasattr(generation_config, "no_timestamps_token_id")
|
| 1377 |
+
and init_tokens[i][-1] != generation_config.no_timestamps_token_id
|
| 1378 |
+
):
|
| 1379 |
+
init_tokens[i].append(generation_config.no_timestamps_token_id)
|
| 1380 |
+
elif (
|
| 1381 |
+
generation_config.return_timestamps and init_tokens[i][-1] == generation_config.no_timestamps_token_id
|
| 1382 |
+
):
|
| 1383 |
+
logger.info(
|
| 1384 |
+
"<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
|
| 1385 |
+
)
|
| 1386 |
+
init_tokens[i] = init_tokens[i][:-1]
|
| 1387 |
+
|
| 1388 |
+
# let's make sure we don't pass `None` tokens as prompt tokens
|
| 1389 |
+
init_tokens[i] = [t for t in init_tokens[i] if t is not None]
|
| 1390 |
+
|
| 1391 |
+
return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
|
| 1392 |
+
|
| 1393 |
+
def detect_language(
|
| 1394 |
+
self,
|
| 1395 |
+
input_features: Optional[torch.FloatTensor] = None,
|
| 1396 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1397 |
+
encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
|
| 1398 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 1399 |
+
num_segment_frames: int = 3000,
|
| 1400 |
+
) -> torch.Tensor:
|
| 1401 |
+
"""
|
| 1402 |
+
Detects language from log-mel input features or encoder_outputs
|
| 1403 |
+
|
| 1404 |
+
Parameters:
|
| 1405 |
+
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
|
| 1406 |
+
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
|
| 1407 |
+
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
|
| 1408 |
+
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
|
| 1409 |
+
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
|
| 1410 |
+
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
|
| 1411 |
+
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
| 1412 |
+
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
| 1413 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
| 1414 |
+
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
| 1415 |
+
generation_config (`~generation.GenerationConfig`, *optional*):
|
| 1416 |
+
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
| 1417 |
+
passed to generate matching the attributes of `generation_config` will override them. If
|
| 1418 |
+
`generation_config` is not provided, the default will be used, which had the following loading
|
| 1419 |
+
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
| 1420 |
+
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
| 1421 |
+
default values, whose documentation should be checked to parameterize generation.
|
| 1422 |
+
num_segment_frames (`int`, *optional*, defaults to 3000):
|
| 1423 |
+
The number of log-mel frames the model expects
|
| 1424 |
+
|
| 1425 |
+
Return:
|
| 1426 |
+
A `torch.LongTensor` representing the detected language ids.
|
| 1427 |
+
"""
|
| 1428 |
+
if input_features is None and encoder_outputs is None:
|
| 1429 |
+
raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
|
| 1430 |
+
elif input_features is not None and encoder_outputs is not None:
|
| 1431 |
+
raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
|
| 1432 |
+
elif input_features is not None:
|
| 1433 |
+
inputs = {"input_features": input_features[:, :, :num_segment_frames]}
|
| 1434 |
+
batch_size = input_features.shape[0]
|
| 1435 |
+
elif encoder_outputs is not None:
|
| 1436 |
+
inputs = {"encoder_outputs": encoder_outputs}
|
| 1437 |
+
batch_size = (
|
| 1438 |
+
encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
|
| 1439 |
+
)
|
| 1440 |
+
if attention_mask is not None:
|
| 1441 |
+
inputs["attention_mask"] = attention_mask
|
| 1442 |
+
|
| 1443 |
+
generation_config = generation_config or self.generation_config
|
| 1444 |
+
decoder_input_ids = (
|
| 1445 |
+
torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
|
| 1446 |
+
* generation_config.decoder_start_token_id
|
| 1447 |
+
)
|
| 1448 |
+
|
| 1449 |
+
with torch.no_grad():
|
| 1450 |
+
logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]
|
| 1451 |
+
|
| 1452 |
+
non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
|
| 1453 |
+
non_lang_mask[list(generation_config.lang_to_id.values())] = False
|
| 1454 |
+
|
| 1455 |
+
logits[:, non_lang_mask] = -np.inf
|
| 1456 |
+
|
| 1457 |
+
lang_ids = logits.argmax(-1)
|
| 1458 |
+
|
| 1459 |
+
return lang_ids
|
| 1460 |
+
|
| 1461 |
+
@staticmethod
|
| 1462 |
+
def _check_decoder_input_ids(kwargs):
|
| 1463 |
+
decoder_input_ids = kwargs.get("decoder_input_ids", None)
|
| 1464 |
+
assistant_model = kwargs.get("assistant_model", None)
|
| 1465 |
+
if decoder_input_ids is not None and assistant_model is not None:
|
| 1466 |
+
raise ValueError(
|
| 1467 |
+
"Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
|
| 1468 |
+
)
|
| 1469 |
+
|
| 1470 |
+
@staticmethod
|
| 1471 |
+
def _set_num_frames(return_token_timestamps, generation_config, kwargs):
|
| 1472 |
+
if return_token_timestamps:
|
| 1473 |
+
if getattr(generation_config, "task", None) == "translate":
|
| 1474 |
+
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
|
| 1475 |
+
if not hasattr(generation_config, "alignment_heads"):
|
| 1476 |
+
raise ValueError(
|
| 1477 |
+
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
|
| 1478 |
+
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
|
| 1479 |
+
)
|
| 1480 |
+
generation_config.num_frames = kwargs.pop("num_frames", None)
|
| 1481 |
+
|
| 1482 |
+
@staticmethod
|
| 1483 |
+
def _set_thresholds_and_condition(
|
| 1484 |
+
generation_config,
|
| 1485 |
+
logprob_threshold,
|
| 1486 |
+
compression_ratio_threshold,
|
| 1487 |
+
no_speech_threshold,
|
| 1488 |
+
condition_on_prev_tokens,
|
| 1489 |
+
):
|
| 1490 |
+
generation_config.logprob_threshold = (
|
| 1491 |
+
logprob_threshold
|
| 1492 |
+
if logprob_threshold is not None
|
| 1493 |
+
else getattr(generation_config, "logprob_threshold", None)
|
| 1494 |
+
)
|
| 1495 |
+
generation_config.compression_ratio_threshold = (
|
| 1496 |
+
compression_ratio_threshold
|
| 1497 |
+
if compression_ratio_threshold is not None
|
| 1498 |
+
else getattr(generation_config, "compression_ratio_threshold", None)
|
| 1499 |
+
)
|
| 1500 |
+
generation_config.no_speech_threshold = (
|
| 1501 |
+
no_speech_threshold
|
| 1502 |
+
if no_speech_threshold is not None
|
| 1503 |
+
else getattr(generation_config, "no_speech_threshold", None)
|
| 1504 |
+
)
|
| 1505 |
+
generation_config.condition_on_prev_tokens = (
|
| 1506 |
+
condition_on_prev_tokens
|
| 1507 |
+
if condition_on_prev_tokens is not None
|
| 1508 |
+
else getattr(generation_config, "condition_on_prev_tokens", None)
|
| 1509 |
+
)
|
| 1510 |
+
|
| 1511 |
+
@staticmethod
|
| 1512 |
+
def _set_prompt_condition_type(generation_config, prompt_condition_type):
|
| 1513 |
+
allowed_cond_types = ["first-segment", "all-segments"]
|
| 1514 |
+
|
| 1515 |
+
# default to "first-segment"
|
| 1516 |
+
prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
|
| 1517 |
+
|
| 1518 |
+
if prompt_condition_type not in allowed_cond_types:
|
| 1519 |
+
raise ValueError(
|
| 1520 |
+
f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
|
| 1521 |
+
)
|
| 1522 |
+
|
| 1523 |
+
if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
|
| 1524 |
+
raise ValueError(
|
| 1525 |
+
"Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
|
| 1526 |
+
)
|
| 1527 |
+
|
| 1528 |
+
generation_config.prompt_condition_type = prompt_condition_type
|
| 1529 |
+
|
| 1530 |
+
@staticmethod
|
| 1531 |
+
def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
|
| 1532 |
+
condition_on_prev_tokens = (
|
| 1533 |
+
condition_on_prev_tokens
|
| 1534 |
+
if condition_on_prev_tokens is not None
|
| 1535 |
+
else getattr(generation_config, "condition_on_prev_tokens", False)
|
| 1536 |
+
)
|
| 1537 |
+
generation_config.condition_on_prev_tokens = condition_on_prev_tokens
|
| 1538 |
+
|
| 1539 |
+
@staticmethod
|
| 1540 |
+
def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform):
|
| 1541 |
+
if batch_size > 1 and not is_shortform and attention_mask is None:
|
| 1542 |
+
raise ValueError(
|
| 1543 |
+
"When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
|
| 1544 |
+
)
|
| 1545 |
+
elif batch_size > 1 and not is_shortform:
|
| 1546 |
+
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
|
| 1547 |
+
seek = torch.zeros((batch_size,), dtype=torch.long)
|
| 1548 |
+
else:
|
| 1549 |
+
max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
|
| 1550 |
+
seek = torch.zeros((batch_size,), dtype=torch.long)
|
| 1551 |
+
|
| 1552 |
+
return max_frames, seek
|
| 1553 |
+
|
| 1554 |
+
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
|
| 1555 |
+
if generation_config.return_timestamps is True:
|
| 1556 |
+
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
|
| 1557 |
+
logits_processor = (
|
| 1558 |
+
[timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
|
| 1559 |
+
)
|
| 1560 |
+
|
| 1561 |
+
if generation_config.suppress_tokens is not None:
|
| 1562 |
+
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
|
| 1563 |
+
logits_processor = (
|
| 1564 |
+
[suppress_tokens_processor]
|
| 1565 |
+
if logits_processor is None
|
| 1566 |
+
else [suppress_tokens_processor] + logits_processor
|
| 1567 |
+
)
|
| 1568 |
+
generation_config.suppress_tokens = None
|
| 1569 |
+
|
| 1570 |
+
if generation_config.begin_suppress_tokens is not None:
|
| 1571 |
+
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
|
| 1572 |
+
generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
|
| 1573 |
+
)
|
| 1574 |
+
logits_processor = (
|
| 1575 |
+
[begin_suppress_processor]
|
| 1576 |
+
if logits_processor is None
|
| 1577 |
+
else [begin_suppress_processor] + logits_processor
|
| 1578 |
+
)
|
| 1579 |
+
generation_config.begin_suppress_tokens = None
|
| 1580 |
+
|
| 1581 |
+
if generation_config.no_speech_threshold is not None:
|
| 1582 |
+
no_speech_detector = WhisperNoSpeechDetection(
|
| 1583 |
+
no_speech_token=generation_config.no_timestamps_token_id - 1,
|
| 1584 |
+
begin_index=begin_index,
|
| 1585 |
+
scores_is_logprobs=num_beams > 1,
|
| 1586 |
+
)
|
| 1587 |
+
logits_processor = (
|
| 1588 |
+
[no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
|
| 1589 |
+
)
|
| 1590 |
+
no_speech_detector.set_model(self)
|
| 1591 |
+
|
| 1592 |
+
return logits_processor
|
| 1593 |
+
|
| 1594 |
+
@staticmethod
|
| 1595 |
+
def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map):
|
| 1596 |
+
prev_bsz = cur_bsz
|
| 1597 |
+
new_batch_idx_map = []
|
| 1598 |
+
for i in range(prev_bsz):
|
| 1599 |
+
prev_i = batch_idx_map[i]
|
| 1600 |
+
if seek[prev_i] >= max_frames[prev_i]:
|
| 1601 |
+
cut_index = i + (cur_bsz - prev_bsz)
|
| 1602 |
+
cur_bsz -= 1
|
| 1603 |
+
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
|
| 1604 |
+
else:
|
| 1605 |
+
# cut out index that goes away
|
| 1606 |
+
new_batch_idx_map.append(prev_i)
|
| 1607 |
+
|
| 1608 |
+
return input_features, cur_bsz, new_batch_idx_map
|
| 1609 |
+
|
| 1610 |
+
@staticmethod
|
| 1611 |
+
def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
|
| 1612 |
+
if input_features is None:
|
| 1613 |
+
return None
|
| 1614 |
+
|
| 1615 |
+
segment_input = []
|
| 1616 |
+
for i in range(cur_bsz):
|
| 1617 |
+
prev_i = batch_idx_map[i]
|
| 1618 |
+
segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]]
|
| 1619 |
+
|
| 1620 |
+
if segment_input_slice.shape[-1] < num_segment_frames:
|
| 1621 |
+
# pad to 3000 if necessary
|
| 1622 |
+
segment_input_slice = F.pad(
|
| 1623 |
+
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
|
| 1624 |
+
)
|
| 1625 |
+
|
| 1626 |
+
segment_input.append(segment_input_slice)
|
| 1627 |
+
|
| 1628 |
+
segment_input = torch.cat(segment_input, dim=0)
|
| 1629 |
+
|
| 1630 |
+
return segment_input
|
| 1631 |
+
|
| 1632 |
+
@staticmethod
|
| 1633 |
+
def _prepare_decoder_input_ids(
|
| 1634 |
+
cur_bsz,
|
| 1635 |
+
init_tokens,
|
| 1636 |
+
current_segments,
|
| 1637 |
+
batch_idx_map,
|
| 1638 |
+
do_condition_on_prev_tokens,
|
| 1639 |
+
prompt_ids,
|
| 1640 |
+
generation_config,
|
| 1641 |
+
config,
|
| 1642 |
+
device,
|
| 1643 |
+
suppress_tokens,
|
| 1644 |
+
kwargs,
|
| 1645 |
+
):
|
| 1646 |
+
if "decoder_input_ids" in kwargs:
|
| 1647 |
+
decoder_input_ids = kwargs.pop("decoder_input_ids")
|
| 1648 |
+
|
| 1649 |
+
return decoder_input_ids, kwargs
|
| 1650 |
+
|
| 1651 |
+
cut_off_length = config.max_target_positions // 2 - 1
|
| 1652 |
+
|
| 1653 |
+
decoder_input_ids = init_tokens[batch_idx_map]
|
| 1654 |
+
|
| 1655 |
+
prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
|
| 1656 |
+
if prev_start_of_text is None:
|
| 1657 |
+
prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None
|
| 1658 |
+
|
| 1659 |
+
if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
|
| 1660 |
+
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
|
| 1661 |
+
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
|
| 1662 |
+
|
| 1663 |
+
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
|
| 1664 |
+
prev_ids = prompt_ids
|
| 1665 |
+
else:
|
| 1666 |
+
one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
|
| 1667 |
+
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
|
| 1668 |
+
|
| 1669 |
+
padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
|
| 1670 |
+
|
| 1671 |
+
prev_tokens = _pad_to_max_length(
|
| 1672 |
+
active_segments,
|
| 1673 |
+
generation_config.pad_token_id,
|
| 1674 |
+
device=device,
|
| 1675 |
+
padding_side="left",
|
| 1676 |
+
padding=padding,
|
| 1677 |
+
bos_token_tensor=prev_ids,
|
| 1678 |
+
cut_off_length=cut_off_length,
|
| 1679 |
+
)
|
| 1680 |
+
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
|
| 1681 |
+
|
| 1682 |
+
kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
|
| 1683 |
+
elif prompt_ids is not None:
|
| 1684 |
+
prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
|
| 1685 |
+
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
|
| 1686 |
+
# make sure `"decoder_attention_mask"` is not passed to forward
|
| 1687 |
+
kwargs.pop("decoder_attention_mask", None)
|
| 1688 |
+
else:
|
| 1689 |
+
# make sure `"decoder_attention_mask"` is not passed to forward
|
| 1690 |
+
kwargs.pop("decoder_attention_mask", None)
|
| 1691 |
+
|
| 1692 |
+
return decoder_input_ids, kwargs
|
| 1693 |
+
|
| 1694 |
+
def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config):
|
| 1695 |
+
max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
|
| 1696 |
+
if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
|
| 1697 |
+
raise ValueError(
|
| 1698 |
+
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
|
| 1699 |
+
f"is {max_new_tokens}. Thus, the combined length of "
|
| 1700 |
+
f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
|
| 1701 |
+
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
|
| 1702 |
+
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
|
| 1703 |
+
f"so that their combined length is less than {self.config.max_target_positions}."
|
| 1704 |
+
)
|
| 1705 |
+
|
| 1706 |
+
num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
|
| 1707 |
+
|
| 1708 |
+
# Make sure we don't get larger than `max_length`
|
| 1709 |
+
if generation_config.max_length is not None and generation_config.max_new_tokens is None:
|
| 1710 |
+
max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
|
| 1711 |
+
logger.info(
|
| 1712 |
+
f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
|
| 1713 |
+
)
|
| 1714 |
+
elif (
|
| 1715 |
+
generation_config.max_new_tokens is not None
|
| 1716 |
+
and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
|
| 1717 |
+
):
|
| 1718 |
+
max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
|
| 1719 |
+
generation_config.max_new_tokens = max_new_tokens
|
| 1720 |
+
|
| 1721 |
+
@staticmethod
|
| 1722 |
+
def _retrieve_compression_ratio(tokens, vocab_size):
|
| 1723 |
+
"""Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes"""
|
| 1724 |
+
length = int(math.log2(vocab_size) / 8) + 1
|
| 1725 |
+
token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()])
|
| 1726 |
+
compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes))
|
| 1727 |
+
|
| 1728 |
+
return compression_ratio
|
| 1729 |
+
|
| 1730 |
+
@staticmethod
|
| 1731 |
+
def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
|
| 1732 |
+
rescale_temperature = temperature if temperature > 0.0 else 1
|
| 1733 |
+
scores = torch.stack(scores).to(tokens.device)
|
| 1734 |
+
|
| 1735 |
+
if scores.shape[0] > tokens.shape[0]:
|
| 1736 |
+
scores = scores[: tokens.shape[0]]
|
| 1737 |
+
else:
|
| 1738 |
+
tokens = tokens[-scores.shape[0] :]
|
| 1739 |
+
|
| 1740 |
+
logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
|
| 1741 |
+
|
| 1742 |
+
# retrieve logprob of selected tokens and sum
|
| 1743 |
+
sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
|
| 1744 |
+
length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
|
| 1745 |
+
|
| 1746 |
+
avg_logprobs = sum_logprobs / (length + 1)
|
| 1747 |
+
return avg_logprobs
|
| 1748 |
+
|
| 1749 |
+
@staticmethod
|
| 1750 |
+
def _retrieve_segment(
|
| 1751 |
+
seek_sequence,
|
| 1752 |
+
seek_outputs,
|
| 1753 |
+
time_offset,
|
| 1754 |
+
timestamp_begin,
|
| 1755 |
+
seek_num_frames,
|
| 1756 |
+
time_precision,
|
| 1757 |
+
input_stride,
|
| 1758 |
+
prev_idx,
|
| 1759 |
+
idx,
|
| 1760 |
+
return_token_timestamps,
|
| 1761 |
+
):
|
| 1762 |
+
# find the predicted "end of segment" predictions of Whisper
|
| 1763 |
+
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
|
| 1764 |
+
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
|
| 1765 |
+
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
| 1766 |
+
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
| 1767 |
+
timestamp_segment_indices.add_(1)
|
| 1768 |
+
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
|
| 1769 |
+
|
| 1770 |
+
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
|
| 1771 |
+
# "end of segment" prediction and slice the decoding into segments accordingly
|
| 1772 |
+
if len(timestamp_segment_indices) > 0:
|
| 1773 |
+
# if the output contains two consecutive timestamp tokens
|
| 1774 |
+
slices = timestamp_segment_indices.tolist()
|
| 1775 |
+
segments = []
|
| 1776 |
+
if single_timestamp_ending:
|
| 1777 |
+
slices.append(len(seek_sequence))
|
| 1778 |
+
|
| 1779 |
+
last_slice = 0
|
| 1780 |
+
# Add each segment to list of all segments
|
| 1781 |
+
for current_slice in slices:
|
| 1782 |
+
sliced_tokens = seek_sequence[last_slice:current_slice]
|
| 1783 |
+
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
|
| 1784 |
+
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
|
| 1785 |
+
segments.append(
|
| 1786 |
+
{
|
| 1787 |
+
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
|
| 1788 |
+
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
|
| 1789 |
+
"tokens": sliced_tokens,
|
| 1790 |
+
"result": seek_outputs[idx],
|
| 1791 |
+
}
|
| 1792 |
+
)
|
| 1793 |
+
if return_token_timestamps:
|
| 1794 |
+
segments[-1]["token_timestamps"] = (
|
| 1795 |
+
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
|
| 1796 |
+
)
|
| 1797 |
+
last_slice = current_slice
|
| 1798 |
+
|
| 1799 |
+
if single_timestamp_ending:
|
| 1800 |
+
# single timestamp at the end means no speech after the last timestamp.
|
| 1801 |
+
segment_offset = seek_num_frames[prev_idx]
|
| 1802 |
+
else:
|
| 1803 |
+
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
| 1804 |
+
# here we throw away all predictions after the last predicted "end of segment"
|
| 1805 |
+
# since we are cutting right in the middle of an audio
|
| 1806 |
+
last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
|
| 1807 |
+
segment_offset = last_timestamp_pos * input_stride
|
| 1808 |
+
else:
|
| 1809 |
+
# If whisper does not predict any "end of segment" token, then
|
| 1810 |
+
# the whole decoding is considered a segment and we add it to the list of segments
|
| 1811 |
+
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
|
| 1812 |
+
last_timestamp_pos = seek_num_frames[prev_idx]
|
| 1813 |
+
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
|
| 1814 |
+
# no consecutive timestamps but it has a timestamp; use the last one.
|
| 1815 |
+
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
|
| 1816 |
+
segments = [
|
| 1817 |
+
{
|
| 1818 |
+
"start": time_offset[prev_idx],
|
| 1819 |
+
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
|
| 1820 |
+
"tokens": seek_sequence,
|
| 1821 |
+
"result": seek_outputs[idx],
|
| 1822 |
+
}
|
| 1823 |
+
]
|
| 1824 |
+
if return_token_timestamps:
|
| 1825 |
+
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
|
| 1826 |
+
segment_offset = seek_num_frames[prev_idx]
|
| 1827 |
+
|
| 1828 |
+
return segments, segment_offset
|
models/glm_speech_tokenizer/modeling_whisper.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/glm_speech_tokenizer/speech_token_extractor.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.append("../../..")
|
| 4 |
+
import io
|
| 5 |
+
import glob
|
| 6 |
+
import math
|
| 7 |
+
import tarfile
|
| 8 |
+
import torch
|
| 9 |
+
import torchaudio
|
| 10 |
+
import safetensors
|
| 11 |
+
from .configuration_whisper import WhisperVQConfig
|
| 12 |
+
from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
|
| 13 |
+
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
|
| 14 |
+
import asyncio
|
| 15 |
+
from .batch_processor import AsyncBatchEngine # 修改为你的路径
|
| 16 |
+
from typing import List, Union, Tuple, Literal, Optional
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SpeechTokenExtractor:
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
model: WhisperVQEncoder,
|
| 23 |
+
feature_extractor: WhisperFeatureExtractor,
|
| 24 |
+
device: Literal["cpu", "cuda", "mps"] | str = "cuda",
|
| 25 |
+
batch_size: int = 32,
|
| 26 |
+
wait_timeout: float = 0.01,
|
| 27 |
+
):
|
| 28 |
+
self.model = model.eval().to(device)
|
| 29 |
+
self.feature_extractor = feature_extractor
|
| 30 |
+
self.device = device
|
| 31 |
+
self.wait_timeout = wait_timeout
|
| 32 |
+
self.dtype = next(model.parameters()).dtype
|
| 33 |
+
|
| 34 |
+
# 帧/采样 stride(用于 pad 对齐 & mask 下采样)
|
| 35 |
+
self.pooling_kernel_size = getattr(model.config, "pooling_kernel_size", 1)
|
| 36 |
+
self.frame_stride = (
|
| 37 |
+
model.conv1.stride[0] *
|
| 38 |
+
model.conv2.stride[0] *
|
| 39 |
+
self.pooling_kernel_size
|
| 40 |
+
)
|
| 41 |
+
self.sample_stride = self.frame_stride * feature_extractor.hop_length
|
| 42 |
+
|
| 43 |
+
# 重采样缓存(放在 device 上)
|
| 44 |
+
self._resamplers: dict[int, torchaudio.transforms.Resample] = {}
|
| 45 |
+
|
| 46 |
+
self._batch_processor = AsyncBatchEngine(
|
| 47 |
+
processing_function=self._batch_extract_async,
|
| 48 |
+
batch_size=batch_size,
|
| 49 |
+
wait_timeout=wait_timeout,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# -------- I/O & 重采样:保持在 device 上 --------
|
| 53 |
+
def _load_audio(self, utt: Union[str, torch.Tensor]) -> torch.Tensor:
|
| 54 |
+
"""读取单条音频 -> 1D float32 waveform(在 self.device 上,采样率16k)。"""
|
| 55 |
+
# print(f"audio type is {type(utt)}")
|
| 56 |
+
if isinstance(utt, torch.Tensor):
|
| 57 |
+
# audio, sr = utt
|
| 58 |
+
audio = utt.to(self.device, non_blocking=True)
|
| 59 |
+
else:
|
| 60 |
+
audio, sr = torchaudio.load(utt) # CPU
|
| 61 |
+
if audio.ndim > 1 and audio.size(0) > 1: # 混单声道
|
| 62 |
+
audio = audio.mean(dim=0, keepdim=True)
|
| 63 |
+
audio = audio.squeeze(0).to(torch.float32).to(self.device, non_blocking=True)
|
| 64 |
+
|
| 65 |
+
return audio # [T] on device
|
| 66 |
+
|
| 67 |
+
# -------- GPU 上做 feature_extractor --------
|
| 68 |
+
def _extract_features_gpu(self, audios: List[torch.Tensor]) -> dict:
|
| 69 |
+
"""
|
| 70 |
+
1) 输入统一转 CPU numpy(float32)(FE 的要求)
|
| 71 |
+
2) 调用 FE,并传 device=self.device,让“输出张量”直接落在 GPU
|
| 72 |
+
3) 若模型是 fp16,仅将 input_features 转 half(mask 不动)
|
| 73 |
+
"""
|
| 74 |
+
# 1) CUDA/CPU Tensor -> CPU numpy
|
| 75 |
+
np_audios = [a.detach().cpu().numpy().astype("float32") for a in audios]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
feats = self.feature_extractor(
|
| 79 |
+
np_audios,
|
| 80 |
+
sampling_rate=16000,
|
| 81 |
+
return_attention_mask=True,
|
| 82 |
+
return_tensors="pt",
|
| 83 |
+
device=self.device, # ← 用得上
|
| 84 |
+
padding="longest",
|
| 85 |
+
pad_to_multiple_of=self.sample_stride,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
feats = {k: (v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v)
|
| 89 |
+
for k, v in feats.items()}
|
| 90 |
+
|
| 91 |
+
# 3) 半精度对齐(只对 input_features)
|
| 92 |
+
if self.dtype == torch.float16 and "input_features" in feats:
|
| 93 |
+
feats["input_features"] = feats["input_features"].half()
|
| 94 |
+
|
| 95 |
+
return feats
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _forward(self, feats: dict) -> List[List[int]]:
|
| 99 |
+
outputs = self.model(**feats)
|
| 100 |
+
tokens = outputs.quantized_token_ids
|
| 101 |
+
# mask 下采样对齐:conv 下采样 × pooling
|
| 102 |
+
attn = feats["attention_mask"][
|
| 103 |
+
:, :: self.model.conv1.stride[0] * self.model.conv2.stride[0]
|
| 104 |
+
][:, :: self.pooling_kernel_size]
|
| 105 |
+
return [t[m.bool()].tolist() for t, m in zip(tokens, attn)]
|
| 106 |
+
|
| 107 |
+
# -------- 同步批接口 --------
|
| 108 |
+
def extract(self, utts: List[Union[str, torch.Tensor]]) -> List[List[int]]:
|
| 109 |
+
"""
|
| 110 |
+
不做 30s 分片,也不做 microbatch。
|
| 111 |
+
直接:加载/重采样 -> GPU 特征提取 -> 前向 -> 对齐输出。
|
| 112 |
+
"""
|
| 113 |
+
audios = [self._load_audio(u) for u in utts] # list[Tensor(T)] on device
|
| 114 |
+
with torch.inference_mode():
|
| 115 |
+
feats = self._extract_features_gpu(audios) # on device
|
| 116 |
+
return self._forward(feats)
|
| 117 |
+
|
| 118 |
+
# -------- 异步批接口(保持你的返回协议)--------
|
| 119 |
+
async def _batch_extract_async(self, utts: List[Union[str, torch.Tensor]]):
|
| 120 |
+
tokens_list = await asyncio.to_thread(self.extract, utts)
|
| 121 |
+
return [{"tokens": t} for t in tokens_list]
|
| 122 |
+
|
| 123 |
+
async def extract_async(self, utt: Union[str, torch.Tensor]):
|
| 124 |
+
result = await self._batch_processor.add_request(single_input=utt)
|
| 125 |
+
feature = result.get("feature")
|
| 126 |
+
return feature.get("tokens")
|
models/glm_speech_tokenizer/test_speech_token_extractor.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
#!/usr/bin/env python3
|
| 4 |
+
# -*- coding: utf-8 -*-
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
sys.path.append("../../..")
|
| 9 |
+
import asyncio
|
| 10 |
+
import time
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torchaudio
|
| 15 |
+
from transformers import WhisperFeatureExtractor
|
| 16 |
+
from arktts.models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
|
| 17 |
+
from speech_token_extractor import SpeechTokenExtractor # 你实现的类
|
| 18 |
+
_RESAMPLE_CACHE: dict[int, torchaudio.transforms.Resample] = {}
|
| 19 |
+
|
| 20 |
+
def ts() -> str:
|
| 21 |
+
return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
| 22 |
+
|
| 23 |
+
def sync_cuda(device: str):
|
| 24 |
+
if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
|
| 25 |
+
torch.cuda.synchronize(device=device)
|
| 26 |
+
|
| 27 |
+
def load_wav_as_tuple(path: str,target_sr: int = 16000):
|
| 28 |
+
"""读取 wav -> (mono_waveform_1d, sample_rate);保持在CPU上交给 extractor 处理。"""
|
| 29 |
+
wav, sr = torchaudio.load(path) # [C, T]
|
| 30 |
+
|
| 31 |
+
if wav.ndim == 2 and wav.size(0) > 1:
|
| 32 |
+
wav = wav.mean(dim=0) # -> [T] 变单声道
|
| 33 |
+
else:
|
| 34 |
+
wav = wav.squeeze(0) # [1, T] -> [T]
|
| 35 |
+
# 保证是连续的 float32(特征器吃 numpy.float32 会更快)
|
| 36 |
+
wav = wav.contiguous().to(torch.float32).cpu()
|
| 37 |
+
if sr != target_sr:
|
| 38 |
+
if sr not in _RESAMPLE_CACHE:
|
| 39 |
+
_RESAMPLE_CACHE[sr] = torchaudio.transforms.Resample(
|
| 40 |
+
orig_freq=sr, new_freq=target_sr
|
| 41 |
+
)
|
| 42 |
+
wav = _RESAMPLE_CACHE[sr](wav.unsqueeze(0)).squeeze(0)
|
| 43 |
+
sr = target_sr
|
| 44 |
+
|
| 45 |
+
# print(f"type wave is {type(wav)}")
|
| 46 |
+
return wav
|
| 47 |
+
|
| 48 |
+
async def main():
|
| 49 |
+
# --- 1️⃣ 路径配置 ---
|
| 50 |
+
MODEL_PATH = "/data/yumu/model/glm-4-voice-tokenizer"
|
| 51 |
+
AUDIO_PATH1 = "/data/yumu/data/audio_data/qiduoduo_tts_out/00000013.wav"
|
| 52 |
+
AUDIO_PATH2 = "/data/yumu/data/audio_data/qiduoduo_tts_out/00000012.wav"
|
| 53 |
+
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 54 |
+
|
| 55 |
+
assert os.path.exists(AUDIO_PATH1), f"音频文件不存在: {AUDIO_PATH1}"
|
| 56 |
+
assert os.path.exists(MODEL_PATH), f"模型路径不存在: {MODEL_PATH}"
|
| 57 |
+
|
| 58 |
+
print(f"[{ts()}] 启动测试")
|
| 59 |
+
print(f" - DEVICE : {DEVICE}")
|
| 60 |
+
print(f" - MODEL_PATH : {MODEL_PATH}")
|
| 61 |
+
print(f" - AUDIO1 : {AUDIO_PATH1}")
|
| 62 |
+
print(f" - AUDIO2 : {AUDIO_PATH2 if os.path.exists(AUDIO_PATH2) else '(不存在,将重复 AUDIO1)'}")
|
| 63 |
+
|
| 64 |
+
# --- 2️⃣ 先把音频读入内存(改动点)---
|
| 65 |
+
audio1 = load_wav_as_tuple(AUDIO_PATH1)
|
| 66 |
+
audio2 = load_wav_as_tuple(AUDIO_PATH2) if os.path.exists(AUDIO_PATH2) else audio1
|
| 67 |
+
|
| 68 |
+
# --- 3️⃣ 加载模型与特征提取器 ---
|
| 69 |
+
print(f"\n[{ts()}] 加载 WhisperVQ 模型与特征提取器中...")
|
| 70 |
+
t0 = time.perf_counter()
|
| 71 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_PATH)
|
| 72 |
+
|
| 73 |
+
model = WhisperVQEncoder.from_pretrained(MODEL_PATH).eval().to(DEVICE)
|
| 74 |
+
if DEVICE.startswith("cuda"):
|
| 75 |
+
model = model.half() # 半精度仅保留一次
|
| 76 |
+
sync_cuda(DEVICE)
|
| 77 |
+
t1 = time.perf_counter()
|
| 78 |
+
print(f"[{ts()}] 模型加载完成,用时 {(t1 - t0)*1000:.1f} ms")
|
| 79 |
+
|
| 80 |
+
# --- 4️⃣ 初始化提取器 ---
|
| 81 |
+
t0 = time.perf_counter()
|
| 82 |
+
extractor = SpeechTokenExtractor(
|
| 83 |
+
model=model,
|
| 84 |
+
feature_extractor=feature_extractor,
|
| 85 |
+
device=DEVICE,
|
| 86 |
+
batch_size=400,
|
| 87 |
+
wait_timeout=0.01,
|
| 88 |
+
)
|
| 89 |
+
sync_cuda(DEVICE)
|
| 90 |
+
t1 = time.perf_counter()
|
| 91 |
+
print(f"[{ts()}] ✅ SpeechTokenExtractor 初始化完成,用时 {(t1 - t0)*1000:.1f} ms")
|
| 92 |
+
|
| 93 |
+
# --- 5️⃣ 同步测试(传入预加载的 (wav, sr) 元组)---
|
| 94 |
+
print(f"\n[{ts()}] [同步模式] extract() 开始")
|
| 95 |
+
t0 = time.perf_counter()
|
| 96 |
+
sync_tokens_list = extractor.extract([audio1]) # ★ 改:不再传路径
|
| 97 |
+
sync_cuda(DEVICE)
|
| 98 |
+
t1 = time.perf_counter()
|
| 99 |
+
sync_tokens = sync_tokens_list[0]
|
| 100 |
+
print(f"[{ts()}] [同步模式] 完成:{len(sync_tokens)} tokens")
|
| 101 |
+
print(f" - 预览:{sync_tokens[:20]} ...")
|
| 102 |
+
print(f" - 耗时:{(t1 - t0)*1000:.1f} ms (单样本)")
|
| 103 |
+
|
| 104 |
+
# --- 6️⃣ 异步测试(同样传入元组)---
|
| 105 |
+
print(f"\n[{ts()}] [异步模式] extract_async() 并发开始")
|
| 106 |
+
|
| 107 |
+
async def async_worker(audio_utt):
|
| 108 |
+
t_a0 = time.perf_counter()
|
| 109 |
+
print(f"type audio_utt is {type(audio_utt)}")
|
| 110 |
+
tokens = await extractor.extract_async(audio_utt) # ★ 改:不再传路径
|
| 111 |
+
sync_cuda(DEVICE)
|
| 112 |
+
t_a1 = time.perf_counter()
|
| 113 |
+
print(f" · → {len(tokens)} tokens, {(t_a1 - t_a0)*1000:.1f} ms")
|
| 114 |
+
return tokens, (t_a1 - t_a0)
|
| 115 |
+
|
| 116 |
+
# 这里保持你原本的 20+20 并发规模,只是把对象换成内存元组
|
| 117 |
+
test_inputs = [audio1] * 2 + [audio2] * 2
|
| 118 |
+
|
| 119 |
+
t0 = time.perf_counter()
|
| 120 |
+
results = await asyncio.gather(*(async_worker(aud) for aud in test_inputs))
|
| 121 |
+
sync_cuda(DEVICE)
|
| 122 |
+
t1 = time.perf_counter()
|
| 123 |
+
|
| 124 |
+
per_req_ms = [dt * 1000 for _, dt in results]
|
| 125 |
+
all_tokens = [tokens for tokens, _ in results]
|
| 126 |
+
|
| 127 |
+
print(f"[{ts()}] [异步模式] 完成")
|
| 128 |
+
print(f" - 总请求数:{len(results)}")
|
| 129 |
+
print(f" - 总耗时 :{(t1 - t0)*1000:.1f} ms")
|
| 130 |
+
print(f" - 单请求耗时(ms):{[round(x,1) for x in per_req_ms]}")
|
| 131 |
+
print(f" - 平均单请求耗时:{(sum(per_req_ms)/len(per_req_ms)):.1f} ms")
|
| 132 |
+
print(f" - 任一结果预览 :{all_tokens[0][:10]}")
|
| 133 |
+
print(f"\n[{ts()}] ✅ 所有测试完成。")
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
asyncio.run(main())
|
models/glm_speech_tokenizer/utils.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import glob
|
| 4 |
+
import math
|
| 5 |
+
import tarfile
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
import safetensors
|
| 9 |
+
from .configuration_whisper import WhisperVQConfig
|
| 10 |
+
from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
|
| 11 |
+
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
|
| 12 |
+
# import asyncio
|
| 13 |
+
# from ..batch_processor import AsyncBatchEngine # 修改为你的路径
|
| 14 |
+
# from typing import List, Union, Tuple, Literal, Optional
|
| 15 |
+
|
| 16 |
+
def load_quantize_encoder(model_path):
|
| 17 |
+
config = WhisperVQConfig.from_pretrained(model_path)
|
| 18 |
+
config.quantize_encoder_only = True
|
| 19 |
+
model = WhisperVQEncoder(config)
|
| 20 |
+
state_dict = {}
|
| 21 |
+
for path in glob.glob(os.path.join(model_path, "model*.safetensors")):
|
| 22 |
+
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
|
| 23 |
+
for key in f.keys():
|
| 24 |
+
if key.startswith("model.encoder."):
|
| 25 |
+
new_key = key[len("model.encoder."):]
|
| 26 |
+
if new_key.startswith("layer_norm"):
|
| 27 |
+
continue
|
| 28 |
+
if new_key.startswith("layers"):
|
| 29 |
+
layer_id = int(new_key.split(".")[1])
|
| 30 |
+
if layer_id >= config.quantize_position:
|
| 31 |
+
continue
|
| 32 |
+
state_dict[new_key] = f.get_tensor(key)
|
| 33 |
+
model.load_state_dict(state_dict)
|
| 34 |
+
model.eval()
|
| 35 |
+
model.cuda()
|
| 36 |
+
return model
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts,device="cuda"):
|
| 43 |
+
with torch.no_grad():
|
| 44 |
+
audios, indices = [], []
|
| 45 |
+
for idx, utt in enumerate(utts):
|
| 46 |
+
if isinstance(utt, tuple):
|
| 47 |
+
audio, sample_rate = utt
|
| 48 |
+
else:
|
| 49 |
+
audio, sample_rate = torchaudio.load(utt)
|
| 50 |
+
audio = audio.to(device)
|
| 51 |
+
if sample_rate != 16000:
|
| 52 |
+
if sample_rate not in _resample_buffer:
|
| 53 |
+
_resample_buffer[sample_rate] = torchaudio.transforms.Resample(
|
| 54 |
+
orig_freq=sample_rate,
|
| 55 |
+
new_freq=16000
|
| 56 |
+
).to(device)
|
| 57 |
+
audio = _resample_buffer[sample_rate](audio)
|
| 58 |
+
# if audio.shape[0] > 1:
|
| 59 |
+
# audio = audio[:1]
|
| 60 |
+
audio = audio[0]
|
| 61 |
+
audio = audio.cpu().numpy()
|
| 62 |
+
time_step = 0
|
| 63 |
+
while time_step * 16000 < audio.shape[0]:
|
| 64 |
+
audio_segment = audio[time_step * 16000: (time_step + 30) * 16000]
|
| 65 |
+
audios.append(audio_segment)
|
| 66 |
+
indices.append(idx)
|
| 67 |
+
time_step += 30
|
| 68 |
+
pooling_kernel_size = model.config.pooling_kernel_size or 1
|
| 69 |
+
stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length
|
| 70 |
+
all_speech_tokens = [[] for _ in range(len(utts))]
|
| 71 |
+
batch_size = 128
|
| 72 |
+
for start in range(0, len(audios), batch_size):
|
| 73 |
+
features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000,
|
| 74 |
+
return_attention_mask=True, return_tensors="pt", device=device,
|
| 75 |
+
padding="longest", pad_to_multiple_of=stride)
|
| 76 |
+
features = features.to(device=device)
|
| 77 |
+
# ✅ 关键修复:如果模型是FP16,则输入也转为FP16
|
| 78 |
+
if next(model.parameters()).dtype == torch.float16:
|
| 79 |
+
features = {k: v.half() for k, v in features.items()}
|
| 80 |
+
outputs = model(**features)
|
| 81 |
+
speech_tokens = outputs.quantized_token_ids
|
| 82 |
+
attention_mask = features["attention_mask"][:, ::model.conv1.stride[0] * model.conv2.stride[0]]
|
| 83 |
+
attention_mask = attention_mask[:, ::model.config.pooling_kernel_size]
|
| 84 |
+
assert attention_mask.shape == speech_tokens.shape
|
| 85 |
+
for i in range(len(speech_tokens)):
|
| 86 |
+
idx = indices[start + i]
|
| 87 |
+
speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
|
| 88 |
+
all_speech_tokens[idx].extend(speech_token)
|
| 89 |
+
return all_speech_tokens
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers==4.57.3
|
| 2 |
+
torch==2.8.0
|
| 3 |
+
librosa
|
| 4 |
+
soundfile
|
| 5 |
+
numpy
|