Spaces:
Sleeping
Sleeping
File size: 10,915 Bytes
6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 0e7565e cda2a68 0e7565e cda2a68 0e7565e cda2a68 6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 6595c29 cda2a68 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 | import gradio as gr
import os
import torch
import torchaudio
import numpy as np
import onnxruntime
import whisper
import io
import librosa
import math
from huggingface_hub import snapshot_download
from funasr import AutoModel
# Utils
def resample_audio(wav, original_sample_rate, target_sample_rate):
    """Resample `wav` to `target_sample_rate`.

    Returns the input unchanged when the rates already match; otherwise
    builds a torchaudio Resample transform and applies it.
    """
    if original_sample_rate == target_sample_rate:
        return wav
    resampler = torchaudio.transforms.Resample(
        orig_freq=original_sample_rate, new_freq=target_sample_rate
    )
    return resampler(wav)
def energy_norm_fn(wav):
    """Peak-normalize `wav` so its maximum absolute value is 0.999.

    Accepts either a numpy array or a torch tensor and returns the same
    type. The divisor is floored at 0.01 so near-silent input is not
    blown up (all-zero input comes back all-zero).
    """
    # Fix: use isinstance() instead of the `type(x) is T` anti-pattern;
    # the two branches only differ in which library computes the peak.
    if isinstance(wav, np.ndarray):
        peak = np.max(np.abs(wav))
    else:
        peak = torch.max(torch.abs(wav))
    return wav / max(peak, 0.01) * 0.999
def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size=240):
    """Trim silence around a 1-D waveform, keeping small silence margins.

    Args:
        audio: 1-D numpy waveform.
        sr: sample rate of `audio` in Hz.
        keep_left_time: seconds of silence to keep before the voiced span.
        keep_right_time: seconds of silence to keep after the voiced span.
        hop_size: frame hop used to round the voiced span to whole frames.

    Returns:
        A waveform of exactly
        `num_frames * hop_size + (keep_left_time + keep_right_time) * sr`
        samples, sliced or zero-padded as needed.
    """
    # librosa returns the (start, end) sample indices of the non-silent span.
    _, index = librosa.effects.trim(audio, top_db=20, frame_length=512, hop_length=128)
    # Voiced span length rounded up to a whole number of hop-size frames.
    num_frames = int(math.ceil((index[1] - index[0]) / hop_size))
    left_sil_samples = int(keep_left_time * sr)
    right_sil_samples = int(keep_right_time * sr)
    # NOTE(review): dead store — wav_len is recomputed after the left trim below.
    wav_len = len(audio)
    start_idx = index[0] - left_sil_samples
    trim_wav = audio
    if start_idx > 0:
        # Enough leading silence exists: drop everything before the margin.
        trim_wav = trim_wav[start_idx:]
    else:
        # Not enough leading silence: left-pad with zeros to fill the margin.
        trim_wav = np.pad(
            trim_wav, (abs(start_idx), 0), mode="constant", constant_values=0.0
        )
    wav_len = len(trim_wav)
    # Target output length: voiced frames plus both kept silence margins.
    out_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr)
    if out_len < wav_len:
        trim_wav = trim_wav[:out_len]
    else:
        # Right-pad with zeros up to the target length.
        trim_wav = np.pad(
            trim_wav, (0, (out_len - wav_len)), mode="constant", constant_values=0.0
        )
    return trim_wav
class StepAudioTokenizer:
    """Speech-to-discrete-token front end for Step-Audio.

    Produces two code streams per utterance and interleaves them 2:3:
      * "vq02": FunASR paraformer encoder features quantized against the
        k-means codebook loaded from `linguistic_tokenizer.npy`.
      * "vq06": codes emitted by the `speech_tokenizer_v1.onnx` model fed
        whisper 128-bin log-mel features.

    `__call__` maps an audio file path to a `"<audio_N>..."` token string.
    """

    def __init__(self):
        """Download model assets from the Hugging Face Hub and build sub-models.

        Raises:
            FileNotFoundError: if the codebook or ONNX tokenizer file is
                missing from the downloaded snapshot.
        """
        model_id = "stepfun-ai/Step-Audio-Tokenizer"
        print(f"Loading model from Hugging Face: {model_id}")
        self.model_dir = snapshot_download(model_id)

        # Streaming paraformer ASR encoder (FunASR), forced onto CPU.
        paraformer_dir = os.path.join(self.model_dir, "dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online")
        print(f"Initializing AutoModel from {paraformer_dir}")
        self.funasr_model = AutoModel(
            model=paraformer_dir,
            model_revision="main",
            device="cpu",
            disable_update=True
        )

        kms_path = os.path.join(self.model_dir, "linguistic_tokenizer.npy")
        cosy_tokenizer_path = os.path.join(self.model_dir, "speech_tokenizer_v1.onnx")
        if not os.path.exists(kms_path):
            raise FileNotFoundError(f"KMS file not found: {kms_path}")
        if not os.path.exists(cosy_tokenizer_path):
            raise FileNotFoundError(f"Cosy tokenizer file not found: {cosy_tokenizer_path}")

        # k-means codebook (one row per cluster centroid) for vq02 quantization.
        self.kms = torch.tensor(np.load(kms_path))

        # Single-threaded CPU ONNX session for the vq06 speech tokenizer.
        providers = ["CPUExecutionProvider"]
        session_option = onnxruntime.SessionOptions()
        session_option.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        session_option.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            cosy_tokenizer_path, sess_options=session_option, providers=providers
        )

        # Streaming-encoder parameters passed through to FunASR.
        self.chunk_size = [0, 4, 5]
        self.encoder_chunk_look_back = 4
        self.decoder_chunk_look_back = 1

        # Locate the encoder-only inference entry point; depending on the
        # funasr version it lives on the wrapper or on the wrapped model.
        if hasattr(self.funasr_model, "infer_encoder"):
            self.infer_func = self.funasr_model.infer_encoder
        elif hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
            self.infer_func = self.funasr_model.model.infer_encoder
        else:
            # Defer the lookup to get_vq02_code, which retries at runtime.
            print("Warning: infer_encoder not found directly. Will check at runtime.")
            self.infer_func = None

    def __call__(self, audio_path):
        """Tokenize the audio file at `audio_path` into a '<audio_N>...' string."""
        audio, sr = torchaudio.load(audio_path)
        # Downmix multi-channel audio to mono before tokenizing.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        # Silence trimming is disabled here (third positional arg = enable_trim).
        _, vq02, vq06 = self.wav2token(audio, sr, False)
        return self.merge_vq0206_to_token_str(vq02, vq06)

    def preprocess_wav(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Resample `audio` to 16 kHz, optionally peak-normalize and trim silence.

        Args:
            audio: (1, T) waveform tensor.
            sample_rate: sample rate of `audio` in Hz.
            enable_trim: trim leading/trailing silence via trim_silence().
            energy_norm: peak-normalize via energy_norm_fn().

        Returns:
            A (1, T') float tensor at 16 kHz.
        """
        audio = resample_audio(audio, sample_rate, 16000)
        if energy_norm:
            audio = energy_norm_fn(audio)
        if enable_trim:
            # trim_silence operates on 1-D numpy arrays.
            mono = audio.cpu().numpy().squeeze(0)
            mono = trim_silence(mono, 16000)
            audio = torch.from_numpy(mono).unsqueeze(0)
        return audio

    def wav2token(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Tokenize a waveform into the interleaved dual-codebook stream.

        Returns:
            (speech_tokens, vq02_ori, vq06_ori) where speech_tokens
            interleaves 2 offset vq02 ids then 3 offset vq06 ids per chunk,
            and the *_ori lists are the raw codebook indices.
            NOTE(review): the +65536 / +1024 offsets presumably shift codes
            into a shared LLM token-id space — confirm against the consumer.
        """
        audio = self.preprocess_wav(
            audio, sample_rate, enable_trim=enable_trim, energy_norm=energy_norm
        )
        vq02_ori = self.get_vq02_code(audio)
        vq02 = [int(x) + 65536 for x in vq02_ori]
        vq06_ori = self.get_vq06_code(audio)
        vq06 = [int(x) + 65536 + 1024 for x in vq06_ori]

        # vq02 runs at 2 codes per chunk, vq06 at 3; interleave 2:3 for as
        # many complete chunks as both streams cover.
        chunk = 1
        chunk_nums = min(len(vq06) // (3 * chunk), len(vq02) // (2 * chunk))
        speech_tokens = []
        for idx in range(chunk_nums):
            speech_tokens += vq02[idx * chunk * 2 : (idx + 1) * chunk * 2]
            speech_tokens += vq06[idx * chunk * 3 : (idx + 1) * chunk * 3]
        return speech_tokens, vq02_ori, vq06_ori

    def get_vq02_code(self, audio):
        """Encode `audio` with the paraformer encoder and quantize to vq02 codes.

        Returns a flat list of k-means codebook indices; empty when the
        encoder produced no features.
        """
        # FunASR accepts file-like inputs, so serialize to an in-memory WAV.
        _tmp_wav = io.BytesIO()
        torchaudio.save(_tmp_wav, audio, 16000, format="wav")
        _tmp_wav.seek(0)

        if self.infer_func is None:
            # __init__ could not find infer_encoder; retry now in case the
            # model object is wrapped differently at runtime.
            if hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
                self.infer_func = self.funasr_model.model.infer_encoder
            elif hasattr(self.funasr_model, "infer_encoder"):
                self.infer_func = self.funasr_model.infer_encoder
            else:
                raise RuntimeError("infer_encoder method not found on FunASR model.")

        try:
            res = self.infer_func(
                input=[_tmp_wav],
                chunk_size=self.chunk_size,
                encoder_chunk_look_back=self.encoder_chunk_look_back,
                decoder_chunk_look_back=self.decoder_chunk_look_back,
                device="cpu",
                is_final=True,
                cache={}
            )
        except TypeError as e:
            # Some funasr versions reject the streaming kwargs; fall back to
            # the minimal signature.
            print(f"Error calling infer_encoder: {e}. Trying different arguments.")
            res = self.infer_func(
                input=[_tmp_wav],
                is_final=True
            )

        if isinstance(res, tuple):
            res = res[0]
        c_list = []
        # NOTE(review): a single utterance is submitted, so `res` is expected
        # to hold one result; with several, only the last non-empty one wins.
        for res_ in res:
            feat = res_["enc_out"]
            if len(feat) > 0:
                c_list = self.dump_label([feat], self.kms)[0]
        return c_list

    def get_vq06_code(self, audio):
        """Run the ONNX speech tokenizer over `audio` in at most 30 s chunks."""

        def split_audio(audio, chunk_duration=480000):
            # Split into fixed-size chunks; drop any chunk shorter than 480
            # samples (30 ms at 16 kHz), which the tokenizer cannot encode.
            start = 0
            chunks = []
            while start < len(audio):
                end = min(start + chunk_duration, len(audio))
                piece = audio[start:end]
                if len(piece) >= 480:
                    chunks.append(piece)
                start = end
            return chunks

        audio = audio.squeeze(0)
        speech_tokens = []
        for chunk in split_audio(audio, chunk_duration=30 * 16000):
            # 128-bin whisper log-mel features, batched as (1, 128, frames).
            feat = whisper.log_mel_spectrogram(chunk, n_mels=128).unsqueeze(0)
            feat_len = np.array([feat.shape[2]], dtype=np.int32)
            ort_inputs = {
                self.ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
                self.ort_session.get_inputs()[1].name: feat_len,
            }
            speech_tokens += self.ort_session.run(None, ort_inputs)[0].flatten().tolist()
        return speech_tokens

    def kmean_cluster(self, samples, means):
        """Assign each row of `samples` to the index of its nearest `means` row (L2)."""
        dists = torch.cdist(samples, means)
        return dists.argmin(dim=1).cpu().numpy().tolist()

    def dump_label(self, samples, mean):
        """Quantize a batch of encoder feature tensors against codebook `mean`.

        Args:
            samples: list of (1, len_i, dims) float tensors.
            mean: (n_clusters, dims) codebook tensor.

        Returns:
            A list with one list of cluster indices per input sample.
        """
        dims = samples[0].shape[-1]
        x_lens = [x.shape[1] for x in samples]
        total_len = sum(x_lens)
        # Concatenate all samples along the time axis into one buffer.
        x_sel = torch.FloatTensor(1, total_len, dims)
        start_len = 0
        for sample in samples:
            sample_len = sample.shape[1]
            end_len = start_len + sample_len
            x_sel[:, start_len:end_len] = sample
            start_len = end_len
        dense_x = x_sel.squeeze(0)
        indices = self.kmean_cluster(dense_x, mean)
        # Split the flat index list back into per-sample slices.
        # BUG FIX: the original computed `end_len = start_len + end_len`
        # (reusing the stale end_len left over from the fill loop) and never
        # advanced start_len, so every sample got the same wrong slice.
        indices_list = []
        start_len = 0
        for x_len in x_lens:
            end_len = start_len + x_len
            indices_list.append(indices[start_len:end_len])
            start_len = end_len
        return indices_list

    def merge_vq0206_to_token_str(self, vq02, vq06):
        """Interleave raw vq02/vq06 codes 2:3 and render them as '<audio_N>' tags.

        vq06 codes are offset by +1024 so the two codebooks do not collide.
        Trailing codes that do not fill a complete 2+3 group are dropped.
        """
        _vq06 = [1024 + x for x in vq06]
        result = []
        i = 0
        j = 0
        while i < len(vq02) - 1 and j < len(_vq06) - 2:
            result.extend(vq02[i : i + 2] + _vq06[j : j + 3])
            i += 2
            j += 3
        return "".join([f"<audio_{x}>" for x in result])
# Lazily-constructed singleton tokenizer shared across Gradio calls.
tokenizer = None


def process_audio(audio_path):
    """Gradio handler: tokenize `audio_path`, returning tokens or an error string."""
    global tokenizer
    if tokenizer is None:
        try:
            tokenizer = StepAudioTokenizer()
        except Exception as e:
            return f"Error loading model: {e}"
    if not audio_path:
        return "Please upload an audio file."
    try:
        return tokenizer(audio_path)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing audio: {e}"
if __name__ == "__main__":
    # Wire the tokenizer handler into a simple one-in/one-out Gradio UI.
    audio_input = gr.Audio(type="filepath", label="Upload WAV")
    token_output = gr.Textbox(label="Token String")
    app = gr.Interface(
        fn=process_audio,
        inputs=audio_input,
        outputs=token_output,
        title="Step Audio Tokenizer",
        description="Upload a WAV file to convert it to token string (<audio_XXX>)."
    )
    app.launch()
|