# tokenizer-demo / app.py
# author: zsc
# commit: asis (0e7565e)
import gradio as gr
import os
import torch
import torchaudio
import numpy as np
import onnxruntime
import whisper
import io
import librosa
import math
from huggingface_hub import snapshot_download
from funasr import AutoModel
# Utils
def resample_audio(wav, original_sample_rate, target_sample_rate):
    """Resample ``wav`` to ``target_sample_rate``; returned unchanged if rates match."""
    if original_sample_rate == target_sample_rate:
        return wav
    resampler = torchaudio.transforms.Resample(
        orig_freq=original_sample_rate, new_freq=target_sample_rate
    )
    return resampler(wav)
def energy_norm_fn(wav):
    """Peak-normalize ``wav`` so its maximum absolute value is ~0.999.

    Accepts either a numpy array or a torch tensor (the caller mixes both).
    The 0.01 floor on the measured peak prevents amplifying near-silent
    signals (and division blow-up when the input is all zeros).
    """
    # isinstance (not `type(...) is`) so ndarray subclasses are handled too.
    if isinstance(wav, np.ndarray):
        peak = np.max(np.abs(wav))
    else:
        peak = torch.max(torch.abs(wav))
    return wav / max(peak, 0.01) * 0.999
def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size=240):
    """Trim leading/trailing silence from a 1-D float waveform.

    Keeps ``keep_left_time`` seconds of context before the detected voice
    onset and ``keep_right_time`` seconds after it, then pads or cuts so the
    output length is a whole number of ``hop_size`` frames plus that context.
    """
    # librosa returns the (start, end) sample indices of the voiced region.
    _, (voice_start, voice_end) = librosa.effects.trim(
        audio, top_db=20, frame_length=512, hop_length=128
    )
    num_frames = int(math.ceil((voice_end - voice_start) / hop_size))
    left_pad = int(keep_left_time * sr)
    right_pad = int(keep_right_time * sr)  # kept for parity; folded into target_len below

    # Shift the start back to keep some left context; zero-pad if the voice
    # begins too close to the start of the clip.
    start_idx = voice_start - left_pad
    if start_idx > 0:
        trimmed = audio[start_idx:]
    else:
        trimmed = np.pad(
            audio, (abs(start_idx), 0), mode="constant", constant_values=0.0
        )

    # Force the final length: voiced frames plus both context windows.
    target_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr)
    current_len = len(trimmed)
    if target_len < current_len:
        return trimmed[:target_len]
    return np.pad(
        trimmed, (0, target_len - current_len), mode="constant", constant_values=0.0
    )
class StepAudioTokenizer:
    """Dual-codebook speech tokenizer (stepfun-ai/Step-Audio-Tokenizer).

    Combines two token streams:
      * "vq02" linguistic tokens from a FunASR Paraformer encoder + k-means
        codebook (``linguistic_tokenizer.npy``), and
      * "vq06" acoustic tokens from an ONNX speech tokenizer
        (``speech_tokenizer_v1.onnx``),
    interleaved 2:3 and rendered as ``<audio_NNN>`` strings.
    """

    def __init__(self):
        # Download (or reuse the cached) model snapshot from the Hub.
        model_id = "stepfun-ai/Step-Audio-Tokenizer"
        print(f"Loading model from Hugging Face: {model_id}")
        self.model_dir = snapshot_download(model_id)
        # Load FunASR model (streaming Paraformer encoder).
        paraformer_dir = os.path.join(self.model_dir, "dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online")
        print(f"Initializing AutoModel from {paraformer_dir}")
        self.funasr_model = AutoModel(
            model=paraformer_dir,
            model_revision="main",
            device="cpu",
            disable_update=True,
        )
        # K-means codebook for the linguistic (vq02) stream and the ONNX
        # acoustic (vq06) tokenizer must both exist in the snapshot.
        kms_path = os.path.join(self.model_dir, "linguistic_tokenizer.npy")
        cosy_tokenizer_path = os.path.join(self.model_dir, "speech_tokenizer_v1.onnx")
        if not os.path.exists(kms_path):
            raise FileNotFoundError(f"KMS file not found: {kms_path}")
        if not os.path.exists(cosy_tokenizer_path):
            raise FileNotFoundError(f"Cosy tokenizer file not found: {cosy_tokenizer_path}")
        self.kms = torch.tensor(np.load(kms_path))
        # CPU-only ONNX session, single-threaded for Spaces-friendly load.
        providers = ["CPUExecutionProvider"]
        session_option = onnxruntime.SessionOptions()
        session_option.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        session_option.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            cosy_tokenizer_path, sess_options=session_option, providers=providers
        )
        # Streaming-encoder parameters passed to FunASR on every call.
        self.chunk_size = [0, 4, 5]
        self.encoder_chunk_look_back = 4
        self.decoder_chunk_look_back = 1
        # Identify the inference function — funasr versions differ in where
        # `infer_encoder` lives (wrapper vs. inner model).
        if hasattr(self.funasr_model, "infer_encoder"):
            self.infer_func = self.funasr_model.infer_encoder
        elif hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
            self.infer_func = self.funasr_model.model.infer_encoder
        else:
            # Try to find it in the model object if it's wrapped differently
            print("Warning: infer_encoder not found directly. Will check at runtime.")
            self.infer_func = None

    def __call__(self, audio_path):
        """Tokenize the audio file at ``audio_path`` into a token string."""
        # Load audio
        audio, sr = torchaudio.load(audio_path)
        # Mix to mono if stereo
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        # enable_trim=False: the Gradio demo tokenizes the clip as uploaded.
        _, vq02, vq06 = self.wav2token(audio, sr, False)
        text = self.merge_vq0206_to_token_str(vq02, vq06)
        return text

    def preprocess_wav(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Resample to 16 kHz, optionally peak-normalize and trim silence.

        Returns a (1, samples) float tensor.
        """
        audio = resample_audio(audio, sample_rate, 16000)
        if energy_norm:
            audio = energy_norm_fn(audio)
        if enable_trim:
            # trim_silence works on 1-D numpy arrays; round-trip the tensor.
            audio = audio.cpu().numpy().squeeze(0)
            audio = trim_silence(audio, 16000)
            audio = torch.from_numpy(audio)
            audio = audio.unsqueeze(0)
        return audio

    def wav2token(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Return (interleaved_tokens, raw_vq02, raw_vq06) for ``audio``.

        The combined stream interleaves 2 vq02 tokens (offset +65536) with
        3 vq06 tokens (offset +65536+1024) per chunk; any tail that does not
        fill a complete 2:3 chunk is dropped.
        """
        audio = self.preprocess_wav(
            audio, sample_rate, enable_trim=enable_trim, energy_norm=energy_norm
        )
        vq02_ori = self.get_vq02_code(audio)
        vq02 = [int(x) + 65536 for x in vq02_ori]
        vq06_ori = self.get_vq06_code(audio)
        vq06 = [int(x) + 65536 + 1024 for x in vq06_ori]
        chunk = 1
        chunk_nums = min(len(vq06) // (3 * chunk), len(vq02) // (2 * chunk))
        speech_tokens = []
        for idx in range(chunk_nums):
            speech_tokens += vq02[idx * chunk * 2 : (idx + 1) * chunk * 2]
            speech_tokens += vq06[idx * chunk * 3 : (idx + 1) * chunk * 3]
        return speech_tokens, vq02_ori, vq06_ori

    def get_vq02_code(self, audio):
        """Run the FunASR encoder on ``audio`` and quantize its features
        against the k-means codebook; returns a list of code indices."""
        # FunASR accepts file-like inputs, so serialize to an in-memory WAV.
        _tmp_wav = io.BytesIO()
        torchaudio.save(_tmp_wav, audio, 16000, format="wav")
        _tmp_wav.seek(0)
        if self.infer_func is None:
            # Last ditch effort to find it
            if hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
                self.infer_func = self.funasr_model.model.infer_encoder
            elif hasattr(self.funasr_model, "infer_encoder"):
                self.infer_func = self.funasr_model.infer_encoder
            else:
                raise RuntimeError("infer_encoder method not found on FunASR model.")
        # Note: depending on the funasr version, the accepted kwargs differ;
        # retry with a minimal signature when the full one is rejected.
        try:
            res = self.infer_func(
                input=[_tmp_wav],
                chunk_size=self.chunk_size,
                encoder_chunk_look_back=self.encoder_chunk_look_back,
                decoder_chunk_look_back=self.decoder_chunk_look_back,
                device="cpu",
                is_final=True,
                cache={}
            )
        except TypeError as e:
            print(f"Error calling infer_encoder: {e}. Trying different arguments.")
            # Maybe it doesn't accept some args
            res = self.infer_func(
                input=[_tmp_wav],
                is_final=True
            )
        if isinstance(res, tuple):
            res = res[0]
        c_list = []
        # NOTE(review): each iteration overwrites c_list, so only the last
        # non-empty result survives. Harmless here because we always submit a
        # single input, but confirm before batching.
        for j, res_ in enumerate(res):
            feat = res_["enc_out"]
            if len(feat) > 0:
                c_list = self.dump_label([feat], self.kms)[0]
        return c_list

    def get_vq06_code(self, audio):
        """Run the ONNX acoustic tokenizer over ``audio`` in 30 s chunks."""

        def split_audio(audio, chunk_duration=480000):
            # Split into fixed-size chunks, dropping any tail shorter than
            # 480 samples (30 ms @ 16 kHz) that the tokenizer cannot encode.
            start = 0
            chunks = []
            while start < len(audio):
                end = min(start + chunk_duration, len(audio))
                chunk = audio[start:end]
                if len(chunk) < 480:
                    pass
                else:
                    chunks.append(chunk)
                start = end
            return chunks

        audio = audio.squeeze(0)
        chunk_audios = split_audio(audio, chunk_duration=30 * 16000)
        speech_tokens = []
        for chunk in chunk_audios:
            duration = round(chunk.shape[0] / 16000, 2)
            # 128-mel Whisper log-mel features, batched to (1, 128, frames).
            feat = whisper.log_mel_spectrogram(chunk, n_mels=128)
            feat = feat.unsqueeze(0)
            feat_len = np.array([feat.shape[2]], dtype=np.int32)
            chunk_token = (
                self.ort_session.run(
                    None,
                    {
                        self.ort_session.get_inputs()[0]
                        .name: feat.detach()
                        .cpu()
                        .numpy(),
                        self.ort_session.get_inputs()[1].name: feat_len,
                    },
                )[0]
                .flatten()
                .tolist()
            )
            speech_tokens += chunk_token
        return speech_tokens

    def kmean_cluster(self, samples, means):
        """Assign each row of ``samples`` to its nearest row in ``means``;
        returns a list of cluster indices."""
        dists = torch.cdist(samples, means)
        indices = dists.argmin(dim=1).cpu().numpy()
        return indices.tolist()

    def dump_label(self, samples, mean):
        """Quantize a list of (1, len_i, dims) feature tensors against the
        ``mean`` codebook; returns one list of code indices per sample."""
        dims = samples[0].shape[-1]
        x_lens = [x.shape[1] for x in samples]
        total_len = sum(x_lens)
        # Concatenate all samples along the time axis into one buffer so the
        # k-means assignment runs in a single cdist call.
        x_sel = torch.FloatTensor(1, total_len, dims)
        start_len = 0
        for sample in samples:
            sample_len = sample.shape[1]
            end_len = start_len + sample_len
            x_sel[:, start_len:end_len] = sample
            start_len = end_len
        dense_x = x_sel.squeeze(0)
        indices = self.kmean_cluster(dense_x, mean)
        # Split the flat index list back per-sample.
        # FIX: the original computed `end_len = start_len + end_len` and never
        # advanced start_len, so every sample after the first received the
        # full index list (only correct by accident for a single sample).
        indices_list = []
        start_len = 0
        for x_len in x_lens:
            end_len = start_len + x_len
            indices_list.append(indices[start_len:end_len])
            start_len = end_len
        return indices_list

    def merge_vq0206_to_token_str(self, vq02, vq06):
        """Interleave 2 vq02 codes with 3 vq06 codes (vq06 offset by +1024)
        and render the stream as ``<audio_N>`` markers."""
        _vq06 = [1024 + x for x in vq06]
        result = []
        i = 0
        j = 0
        while i < len(vq02) - 1 and j < len(_vq06) - 2:
            sublist = vq02[i : i + 2] + _vq06[j : j + 3]
            result.extend(sublist)
            i += 2
            j += 3
        return "".join([f"<audio_{x}>" for x in result])
# Lazily-constructed singleton: model download/load is expensive, so it is
# deferred to the first request instead of import time.
tokenizer = None


def process_audio(audio_path):
    """Gradio handler: convert an uploaded audio file into a token string."""
    global tokenizer
    if tokenizer is None:
        try:
            tokenizer = StepAudioTokenizer()
        except Exception as e:
            return f"Error loading model: {e}"
    if not audio_path:
        return "Please upload an audio file."
    try:
        return tokenizer(audio_path)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing audio: {e}"
# Script entry point: build and launch the Gradio UI.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=process_audio,
        inputs=gr.Audio(type="filepath", label="Upload WAV"),
        outputs=gr.Textbox(label="Token String"),
        title="Step Audio Tokenizer",
        description="Upload a WAV file to convert it to token string (<audio_XXX>)."
    )
    demo.launch()