"""Gradio demo that converts a WAV file into a Step-Audio token string.

Pipeline: resample to 16 kHz -> peak-normalize -> trim silence -> extract two
code streams (a FunASR paraformer encoder quantized against a k-means
codebook, "vq02", and an ONNX speech tokenizer, "vq06") -> interleave them
2:3 and render each code as a token.
"""

import gradio as gr
import os
import torch
import torchaudio
import numpy as np
import onnxruntime
import whisper
import io
import librosa
import math
from huggingface_hub import snapshot_download
from funasr import AutoModel


# --------------------------------------------------------------------------
# Utils
# --------------------------------------------------------------------------

def resample_audio(wav, original_sample_rate, target_sample_rate):
    """Resample *wav* to *target_sample_rate*; no-op when the rates match."""
    if original_sample_rate != target_sample_rate:
        wav = torchaudio.transforms.Resample(
            orig_freq=original_sample_rate, new_freq=target_sample_rate
        )(wav)
    return wav


def energy_norm_fn(wav):
    """Peak-normalize to 0.999.

    The 0.01 floor on the peak avoids blowing up near-silent input.
    Accepts either a numpy array or a torch tensor.
    """
    if isinstance(wav, np.ndarray):
        max_data = np.max(np.abs(wav))
    else:
        max_data = torch.max(torch.abs(wav))
    return wav / max(max_data, 0.01) * 0.999


def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size=240):
    """Trim leading/trailing silence from a 1-D numpy waveform.

    Keeps *keep_left_time* / *keep_right_time* seconds of padding around the
    detected speech and rounds the output up to a whole number of *hop_size*
    frames so frame-based downstream models see an integral frame count.
    """
    _, index = librosa.effects.trim(audio, top_db=20, frame_length=512, hop_length=128)
    num_frames = int(math.ceil((index[1] - index[0]) / hop_size))
    left_sil_samples = int(keep_left_time * sr)

    start_idx = index[0] - left_sil_samples
    if start_idx > 0:
        trim_wav = audio[start_idx:]
    else:
        # Not enough leading audio to keep: pad the front with zeros instead.
        trim_wav = np.pad(
            audio, (abs(start_idx), 0), mode="constant", constant_values=0.0
        )

    wav_len = len(trim_wav)
    out_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr)
    if out_len < wav_len:
        trim_wav = trim_wav[:out_len]
    else:
        trim_wav = np.pad(
            trim_wav, (0, out_len - wav_len), mode="constant", constant_values=0.0
        )
    return trim_wav


class StepAudioTokenizer:
    """Converts a waveform into the interleaved Step-Audio token stream."""

    def __init__(self):
        model_id = "stepfun-ai/Step-Audio-Tokenizer"
        print(f"Loading model from Hugging Face: {model_id}")
        self.model_dir = snapshot_download(model_id)

        # FunASR streaming paraformer: produces the encoder features that are
        # quantized into the low-rate (vq02) stream.
        paraformer_dir = os.path.join(
            self.model_dir,
            "dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
        )
        print(f"Initializing AutoModel from {paraformer_dir}")
        self.funasr_model = AutoModel(
            model=paraformer_dir,
            model_revision="main",
            device="cpu",
            disable_update=True,
        )

        kms_path = os.path.join(self.model_dir, "linguistic_tokenizer.npy")
        cosy_tokenizer_path = os.path.join(self.model_dir, "speech_tokenizer_v1.onnx")
        if not os.path.exists(kms_path):
            raise FileNotFoundError(f"KMS file not found: {kms_path}")
        if not os.path.exists(cosy_tokenizer_path):
            raise FileNotFoundError(f"Cosy tokenizer file not found: {cosy_tokenizer_path}")

        # K-means codebook used to quantize the paraformer encoder output.
        self.kms = torch.tensor(np.load(kms_path))

        providers = ["CPUExecutionProvider"]
        session_option = onnxruntime.SessionOptions()
        session_option.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        session_option.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            cosy_tokenizer_path, sess_options=session_option, providers=providers
        )

        # Streaming-decoder settings passed through to FunASR.
        self.chunk_size = [0, 4, 5]
        self.encoder_chunk_look_back = 4
        self.decoder_chunk_look_back = 1

        # FunASR's wrapper layout varies across versions; locate infer_encoder
        # now if possible, otherwise defer the lookup to the first call.
        if hasattr(self.funasr_model, "infer_encoder"):
            self.infer_func = self.funasr_model.infer_encoder
        elif hasattr(self.funasr_model, "model") and hasattr(
            self.funasr_model.model, "infer_encoder"
        ):
            self.infer_func = self.funasr_model.model.infer_encoder
        else:
            print("Warning: infer_encoder not found directly. Will check at runtime.")
            self.infer_func = None

    def __call__(self, audio_path):
        """Tokenize the audio file at *audio_path*; returns the token string."""
        audio, sr = torchaudio.load(audio_path)
        # Mix down to mono if the file is multi-channel.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        _, vq02, vq06 = self.wav2token(audio, sr, False)
        return self.merge_vq0206_to_token_str(vq02, vq06)

    def preprocess_wav(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Resample to 16 kHz and optionally normalize / trim silence.

        Expects and returns a (1, samples) tensor.
        """
        audio = resample_audio(audio, sample_rate, 16000)
        if energy_norm:
            audio = energy_norm_fn(audio)
        if enable_trim:
            # trim_silence works on a 1-D numpy array.
            audio = trim_silence(audio.cpu().numpy().squeeze(0), 16000)
            audio = torch.from_numpy(audio).unsqueeze(0)
        return audio

    def wav2token(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Return (interleaved tokens, raw vq02 codes, raw vq06 codes).

        vq02 codes are offset by 65536 and vq06 codes by 65536 + 1024; the
        interleaved stream alternates 2 vq02 tokens with 3 vq06 tokens.
        """
        audio = self.preprocess_wav(
            audio, sample_rate, enable_trim=enable_trim, energy_norm=energy_norm
        )
        vq02_ori = self.get_vq02_code(audio)
        vq02 = [int(x) + 65536 for x in vq02_ori]
        vq06_ori = self.get_vq06_code(audio)
        vq06 = [int(x) + 65536 + 1024 for x in vq06_ori]

        chunk = 1
        chunk_nums = min(len(vq06) // (3 * chunk), len(vq02) // (2 * chunk))
        speech_tokens = []
        for idx in range(chunk_nums):
            speech_tokens += vq02[idx * chunk * 2 : (idx + 1) * chunk * 2]
            speech_tokens += vq06[idx * chunk * 3 : (idx + 1) * chunk * 3]
        return speech_tokens, vq02_ori, vq06_ori

    def get_vq02_code(self, audio):
        """Run the FunASR encoder on *audio* and quantize its output.

        Returns the raw (un-offset) k-means indices for the last non-empty
        encoder output.
        """
        _tmp_wav = io.BytesIO()
        torchaudio.save(_tmp_wav, audio, 16000, format="wav")
        _tmp_wav.seek(0)

        if self.infer_func is None:
            # Deferred lookup (see __init__): try both wrapper layouts again.
            if hasattr(self.funasr_model, "model") and hasattr(
                self.funasr_model.model, "infer_encoder"
            ):
                self.infer_func = self.funasr_model.model.infer_encoder
            elif hasattr(self.funasr_model, "infer_encoder"):
                self.infer_func = self.funasr_model.infer_encoder
            else:
                raise RuntimeError("infer_encoder method not found on FunASR model.")

        # FunASR accepts paths, bytes or arrays depending on version; pass the
        # in-memory WAV and retry with a minimal signature on TypeError.
        try:
            res = self.infer_func(
                input=[_tmp_wav],
                chunk_size=self.chunk_size,
                encoder_chunk_look_back=self.encoder_chunk_look_back,
                decoder_chunk_look_back=self.decoder_chunk_look_back,
                device="cpu",
                is_final=True,
                cache={},
            )
        except TypeError as e:
            print(f"Error calling infer_encoder: {e}. Trying different arguments.")
            res = self.infer_func(input=[_tmp_wav], is_final=True)

        if isinstance(res, tuple):
            res = res[0]
        c_list = []
        for res_ in res:
            feat = res_["enc_out"]
            if len(feat) > 0:
                c_list = self.dump_label([feat], self.kms)[0]
        return c_list

    def get_vq06_code(self, audio):
        """Tokenize *audio* with the ONNX speech tokenizer in 30-second chunks."""

        def split_audio(signal, chunk_duration=480000):
            # Fixed-size chunks; drop tail fragments shorter than 480 samples
            # (30 ms at 16 kHz) which the tokenizer cannot process.
            chunks = []
            start = 0
            total = len(signal)
            while start < total:
                end = min(start + chunk_duration, total)
                piece = signal[start:end]
                if len(piece) >= 480:
                    chunks.append(piece)
                start = end
            return chunks

        audio = audio.squeeze(0)
        speech_tokens = []
        for chunk in split_audio(audio, chunk_duration=30 * 16000):
            feat = whisper.log_mel_spectrogram(chunk, n_mels=128).unsqueeze(0)
            feat_len = np.array([feat.shape[2]], dtype=np.int32)
            ort_inputs = {
                self.ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
                self.ort_session.get_inputs()[1].name: feat_len,
            }
            speech_tokens += self.ort_session.run(None, ort_inputs)[0].flatten().tolist()
        return speech_tokens

    def kmean_cluster(self, samples, means):
        """Assign each row of *samples* to its nearest centroid in *means*."""
        dists = torch.cdist(samples, means)
        return dists.argmin(dim=1).cpu().numpy().tolist()

    def dump_label(self, samples, mean):
        """Concatenate per-utterance features, quantize them against *mean*,
        and split the indices back into one list per input sample.
        """
        dims = samples[0].shape[-1]
        x_lens = [x.shape[1] for x in samples]
        total_len = sum(x_lens)
        x_sel = torch.FloatTensor(1, total_len, dims)
        start_len = 0
        for sample in samples:
            end_len = start_len + sample.shape[1]
            x_sel[:, start_len:end_len] = sample
            start_len = end_len

        dense_x = x_sel.squeeze(0)
        indices = self.kmean_cluster(dense_x, mean)

        indices_list = []
        start_len = 0
        for x_len in x_lens:
            # BUG FIX: was `end_len = start_len + end_len` (reusing the stale
            # offset from the copy loop) and start_len was never advanced, so
            # every slice after the first was wrong.
            end_len = start_len + x_len
            indices_list.append(indices[start_len:end_len])
            start_len = end_len
        return indices_list

    def merge_vq0206_to_token_str(self, vq02, vq06):
        """Interleave 2 vq02 codes with 3 offset vq06 codes per chunk and
        render each code as a token.
        """
        _vq06 = [1024 + x for x in vq06]
        result = []
        i = 0
        j = 0
        while i < len(vq02) - 1 and j < len(_vq06) - 2:
            result.extend(vq02[i : i + 2] + _vq06[j : j + 3])
            i += 2
            j += 3
        # BUG FIX: the join used an empty f-string (f""), so the demo always
        # returned an empty string. Render codes as <audio_N> tokens —
        # presumably the Step-Audio convention; TODO confirm against the
        # upstream tokenizer's vocabulary.
        return "".join(f"<audio_{x}>" for x in result)


tokenizer = None


def process_audio(audio_path):
    """Gradio callback: lazily build the tokenizer, then tokenize the file.

    Returns either the token string or a human-readable error message.
    """
    global tokenizer
    if tokenizer is None:
        try:
            tokenizer = StepAudioTokenizer()
        except Exception as e:
            return f"Error loading model: {e}"
    if not audio_path:
        return "Please upload an audio file."
    try:
        return tokenizer(audio_path)
    except Exception as e:
        import traceback

        traceback.print_exc()
        return f"Error processing audio: {e}"


if __name__ == "__main__":
    demo = gr.Interface(
        fn=process_audio,
        inputs=gr.Audio(type="filepath", label="Upload WAV"),
        outputs=gr.Textbox(label="Token String"),
        title="Step Audio Tokenizer",
        description="Upload a WAV file to convert it to a token string.",
    )
    demo.launch()