"""Step-Audio tokenizer Gradio Space: converts uploaded WAV audio to <audio_N> token strings."""
import gradio as gr
import os
import torch
import torchaudio
import numpy as np
import onnxruntime
import whisper
import io
import librosa
import math
from huggingface_hub import snapshot_download
from funasr import AutoModel

# Utils
def resample_audio(wav, original_sample_rate, target_sample_rate):
    """Resample *wav* to *target_sample_rate*; returns the input unchanged when rates match."""
    if original_sample_rate == target_sample_rate:
        return wav
    resampler = torchaudio.transforms.Resample(
        orig_freq=original_sample_rate, new_freq=target_sample_rate
    )
    return resampler(wav)
def energy_norm_fn(wav):
    """Peak-normalize *wav* (numpy array or torch tensor) to ~0.999 amplitude.

    The 0.01 floor on the peak keeps near-silent input from being blown up
    (and avoids division by zero for all-zero input).
    """
    # isinstance instead of `type(...) is` (handles ndarray subclasses);
    # the normalization itself is identical for both backends.
    if isinstance(wav, np.ndarray):
        peak = np.max(np.abs(wav))
    else:
        peak = torch.max(torch.abs(wav))
    return wav / max(peak, 0.01) * 0.999
def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size=240):
    """Trim silence from a 1-D waveform while keeping short pads on each side.

    The output length is frame-aligned: ``num_frames * hop_size`` plus the kept
    left/right padding time, zero-padded or truncated as required.
    """
    # Voiced-region sample boundaries per librosa's 20 dB gate.
    _, (voiced_start, voiced_end) = librosa.effects.trim(
        audio, top_db=20, frame_length=512, hop_length=128
    )
    num_frames = int(math.ceil((voiced_end - voiced_start) / hop_size))
    left_sil_samples = int(keep_left_time * sr)
    right_sil_samples = int(keep_right_time * sr)  # NOTE: kept for parity with original; unused below
    # Shift the start left to retain keep_left_time of context; zero-pad when the
    # voiced region begins closer to the clip start than that margin.
    start_idx = voiced_start - left_sil_samples
    if start_idx > 0:
        out = audio[start_idx:]
    else:
        out = np.pad(audio, (abs(start_idx), 0), mode="constant", constant_values=0.0)
    # Force the final, frame-aligned length: truncate or right-pad with zeros.
    target_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr)
    current_len = len(out)
    if target_len < current_len:
        return out[:target_len]
    return np.pad(out, (0, target_len - current_len), mode="constant", constant_values=0.0)
class StepAudioTokenizer:
    """Converts a waveform into Step-Audio discrete ``<audio_N>`` token strings.

    Two token streams are combined:
      * vq02 — linguistic tokens: streaming Paraformer encoder features,
        k-means-quantized against ``linguistic_tokenizer.npy``.
      * vq06 — semantic tokens from the ``speech_tokenizer_v1.onnx`` model.
    The streams are interleaved 2 vq02 : 3 vq06 per chunk and rendered as text.
    """

    def __init__(self):
        model_id = "stepfun-ai/Step-Audio-Tokenizer"
        print(f"Loading model from Hugging Face: {model_id}")
        self.model_dir = snapshot_download(model_id)

        # Load FunASR model (used only for its encoder output, not full ASR).
        paraformer_dir = os.path.join(
            self.model_dir,
            "dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
        )
        print(f"Initializing AutoModel from {paraformer_dir}")
        self.funasr_model = AutoModel(
            model=paraformer_dir,
            model_revision="main",
            device="cpu",
            disable_update=True,
        )

        kms_path = os.path.join(self.model_dir, "linguistic_tokenizer.npy")
        cosy_tokenizer_path = os.path.join(self.model_dir, "speech_tokenizer_v1.onnx")
        if not os.path.exists(kms_path):
            raise FileNotFoundError(f"KMS file not found: {kms_path}")
        if not os.path.exists(cosy_tokenizer_path):
            raise FileNotFoundError(f"Cosy tokenizer file not found: {cosy_tokenizer_path}")

        # K-means centroids used to quantize the Paraformer encoder features.
        self.kms = torch.tensor(np.load(kms_path))

        providers = ["CPUExecutionProvider"]
        session_option = onnxruntime.SessionOptions()
        session_option.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        session_option.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            cosy_tokenizer_path, sess_options=session_option, providers=providers
        )

        # Streaming-encoder chunking parameters for the online Paraformer.
        self.chunk_size = [0, 4, 5]
        self.encoder_chunk_look_back = 4
        self.decoder_chunk_look_back = 1

        # Locate the encoder inference entry point; its attribute path varies
        # across funasr versions, so fall back to a runtime lookup if missing.
        if hasattr(self.funasr_model, "infer_encoder"):
            self.infer_func = self.funasr_model.infer_encoder
        elif hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
            self.infer_func = self.funasr_model.model.infer_encoder
        else:
            print("Warning: infer_encoder not found directly. Will check at runtime.")
            self.infer_func = None

    def __call__(self, audio_path):
        """Tokenize the audio file at *audio_path*; returns an ``<audio_N>...`` string."""
        audio, sr = torchaudio.load(audio_path)
        # Mix to mono if stereo.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        _, vq02, vq06 = self.wav2token(audio, sr, False)
        return self.merge_vq0206_to_token_str(vq02, vq06)

    def preprocess_wav(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Resample to 16 kHz, optionally peak-normalize and trim silence.

        Returns a (1, T) torch tensor.
        """
        audio = resample_audio(audio, sample_rate, 16000)
        if energy_norm:
            audio = energy_norm_fn(audio)
        if enable_trim:
            # trim_silence works on a 1-D numpy array; restore (1, T) afterwards.
            audio = audio.cpu().numpy().squeeze(0)
            audio = trim_silence(audio, 16000)
            audio = torch.from_numpy(audio)
            audio = audio.unsqueeze(0)
        return audio

    def wav2token(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Return (interleaved speech tokens, raw vq02 ids, raw vq06 ids)."""
        audio = self.preprocess_wav(
            audio, sample_rate, enable_trim=enable_trim, energy_norm=energy_norm
        )
        vq02_ori = self.get_vq02_code(audio)
        vq02 = [int(x) + 65536 for x in vq02_ori]  # offset into the shared token id space
        vq06_ori = self.get_vq06_code(audio)
        vq06 = [int(x) + 65536 + 1024 for x in vq06_ori]  # vq06 ids follow the 1024 vq02 ids
        # Interleave 2 vq02 tokens with 3 vq06 tokens per chunk; any remainder
        # that cannot fill a full 2+3 group is dropped.
        chunk = 1
        chunk_nums = min(len(vq06) // (3 * chunk), len(vq02) // (2 * chunk))
        speech_tokens = []
        for idx in range(chunk_nums):
            speech_tokens += vq02[idx * chunk * 2 : (idx + 1) * chunk * 2]
            speech_tokens += vq06[idx * chunk * 3 : (idx + 1) * chunk * 3]
        return speech_tokens, vq02_ori, vq06_ori

    def get_vq02_code(self, audio):
        """Linguistic (vq02) ids: Paraformer encoder features quantized by k-means."""
        _tmp_wav = io.BytesIO()
        torchaudio.save(_tmp_wav, audio, 16000, format="wav")
        _tmp_wav.seek(0)
        if self.infer_func is None:
            # Last-ditch effort to locate the encoder entry point at call time.
            if hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
                self.infer_func = self.funasr_model.model.infer_encoder
            elif hasattr(self.funasr_model, "infer_encoder"):
                self.infer_func = self.funasr_model.infer_encoder
            else:
                raise RuntimeError("infer_encoder method not found on FunASR model.")
        # funasr accepts paths, bytes or arrays depending on version; pass the
        # in-memory WAV and retry with a minimal signature on TypeError.
        try:
            res = self.infer_func(
                input=[_tmp_wav],
                chunk_size=self.chunk_size,
                encoder_chunk_look_back=self.encoder_chunk_look_back,
                decoder_chunk_look_back=self.decoder_chunk_look_back,
                device="cpu",
                is_final=True,
                cache={},
            )
        except TypeError as e:
            print(f"Error calling infer_encoder: {e}. Trying different arguments.")
            res = self.infer_func(input=[_tmp_wav], is_final=True)
        if isinstance(res, tuple):
            res = res[0]
        c_list = []
        # NOTE: mirrors original behavior — only the last non-empty result is kept.
        for res_ in res:
            feat = res_["enc_out"]
            if len(feat) > 0:
                c_list = self.dump_label([feat], self.kms)[0]
        return c_list

    def get_vq06_code(self, audio):
        """Semantic (vq06) ids from the ONNX speech tokenizer, processed 30 s at a time."""

        def split_audio(wav, chunk_duration=480000):
            # Fixed-size chunks; a tail shorter than 480 samples (30 ms) is dropped.
            chunks = []
            start = 0
            while start < len(wav):
                end = min(start + chunk_duration, len(wav))
                piece = wav[start:end]
                if len(piece) >= 480:
                    chunks.append(piece)
                start = end
            return chunks

        audio = audio.squeeze(0)
        chunk_audios = split_audio(audio, chunk_duration=30 * 16000)
        speech_tokens = []
        for chunk in chunk_audios:
            feat = whisper.log_mel_spectrogram(chunk, n_mels=128).unsqueeze(0)
            feat_len = np.array([feat.shape[2]], dtype=np.int32)
            ort_inputs = {
                self.ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
                self.ort_session.get_inputs()[1].name: feat_len,
            }
            chunk_token = self.ort_session.run(None, ort_inputs)[0].flatten().tolist()
            speech_tokens += chunk_token
        return speech_tokens

    def kmean_cluster(self, samples, means):
        """Assign each row of *samples* (N, D) to the nearest centroid in *means* (K, D)."""
        dists = torch.cdist(samples, means)
        indices = dists.argmin(dim=1).cpu().numpy()
        return indices.tolist()

    def dump_label(self, samples, mean):
        """Quantize a list of (1, T_i, D) feature tensors; returns one label list per sample.

        Bug fix: the original split loop computed ``end_len = start_len + end_len``
        and never advanced ``start_len``, yielding wrong slices whenever more than
        one sample was passed; segments now use each sample's own length.
        """
        dims = samples[0].shape[-1]
        x_lens = [x.shape[1] for x in samples]
        total_len = sum(x_lens)
        # Concatenate all samples along time into one (1, total_len, dims) buffer.
        x_sel = torch.FloatTensor(1, total_len, dims)
        start_len = 0
        for sample in samples:
            end_len = start_len + sample.shape[1]
            x_sel[:, start_len:end_len] = sample
            start_len = end_len
        dense_x = x_sel.squeeze(0)
        indices = self.kmean_cluster(dense_x, mean)
        # Split the flat label sequence back into per-sample lists.
        indices_list = []
        start_len = 0
        for x_len in x_lens:
            end_len = start_len + x_len
            indices_list.append(indices[start_len:end_len])
            start_len = end_len
        return indices_list

    def merge_vq0206_to_token_str(self, vq02, vq06):
        """Interleave raw ids (2 vq02 : 3 vq06, vq06 offset by 1024) as ``<audio_N>`` text."""
        _vq06 = [1024 + x for x in vq06]
        result = []
        i = 0
        j = 0
        while i < len(vq02) - 1 and j < len(_vq06) - 2:
            result.extend(vq02[i : i + 2] + _vq06[j : j + 3])
            i += 2
            j += 3
        return "".join([f"<audio_{x}>" for x in result])
# Lazily-constructed singleton tokenizer; the model download happens on first request.
tokenizer = None


def process_audio(audio_path):
    """Gradio handler: tokenize an uploaded audio file.

    Returns the token string on success, or a human-readable error string.
    """
    global tokenizer
    if tokenizer is None:
        try:
            tokenizer = StepAudioTokenizer()
        except Exception as e:
            return f"Error loading model: {e}"
    if not audio_path:
        return "Please upload an audio file."
    try:
        return tokenizer(audio_path)
    except Exception as e:
        import traceback

        traceback.print_exc()
        return f"Error processing audio: {e}"
if __name__ == "__main__":
    # Build and launch the Gradio UI around the process_audio handler.
    interface_kwargs = dict(
        fn=process_audio,
        inputs=gr.Audio(type="filepath", label="Upload WAV"),
        outputs=gr.Textbox(label="Token String"),
        title="Step Audio Tokenizer",
        description="Upload a WAV file to convert it to token string (<audio_XXX>).",
    )
    demo = gr.Interface(**interface_kwargs)
    demo.launch()