# import ztu_ort_enable_debug
import re
import time
import jieba
import torch
import ztu_somemodelruntime_rknnlite2 as rknnort
import onnxruntime
import soundfile as sf
import numpy as np
from pydub import AudioSegment
from pypinyin import lazy_pinyin, Style
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Model / output paths
# ---------------------------------------------------------------------------
# F5-TTS vocabulary file, available from
# https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base
vocab_path = "vocab.txt"
onnx_model_A = "F5_Preprocess.onnx"           # exported preprocess model
onnx_model_B = "F5_Transformer_opset19.onnx"  # exported transformer model
onnx_model_C = "F5_Decode.onnx"               # exported decode model
generated_audio = "./generated_audio-rknn2.wav"

# ---------------------------------------------------------------------------
# Prompt selection: reference audio, its transcript, and the text to speak
# ---------------------------------------------------------------------------
test_in_english = False
if not test_in_english:
    reference_audio = "basic_ref_zh.wav"  # reference (prompt) audio path
    ref_text = "对,这就是我,万人敬仰的太乙真人。"  # ASR transcript of the reference audio
    gen_text = "哇,这个辣鸡芯片居然也能运行模型。"  # target text to synthesise
else:
    reference_audio = "basic_ref_en.wav"
    ref_text = "Some call me nature, others call me mother nature."
    gen_text = "Some call me Dake, others call me QQ."

# ---------------------------------------------------------------------------
# Inference settings
# ---------------------------------------------------------------------------
# RANDOM_SEED = 9527  # enable to reproduce the generated audio
NFE_STEP = 32               # F5-TTS NFE step setting (steps indexed 0~31)
FUSE_NFE = 1                # must equal the value used when exporting the model
SPEED = 1.0                 # talking-speed factor; only works with dynamic_axes=True
MODEL_SAMPLE_RATE = 24000   # fixed by the model; do not modify
HOP_LENGTH = 256            # affects generated audio length and speech speed
# Build the token -> id map from the vocabulary file (one token per line; the
# line index is the id).
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab_char_map = {}
    for i, char in enumerate(f):
        # Drop only a trailing newline: tokens such as " " must survive, and a
        # final line without a newline must not lose its last character.
        token = char[:-1] if char.endswith("\n") else char
        vocab_char_map[token] = i
vocab_size = len(vocab_char_map)


# From the official code
def convert_char_to_pinyin(text_list, polyphone=True):
    """Convert each text into the token list expected by the F5-TTS vocabulary.

    Args:
        text_list: iterable of strings to convert.
        polyphone: when True, a jieba segment made purely of 3-byte UTF-8
            characters is passed to ``lazy_pinyin`` as a whole word (presumably
            so pypinyin can use word context for polyphones); otherwise such
            characters are converted one at a time.

    Returns:
        A list with one token list per input text. ASCII segments are kept as
        individual characters; characters in U+3100..U+9FFF become a " " token
        followed by their TONE3 pinyin syllable; other characters are kept.
    """
    if not jieba.dt.initialized:
        jieba.default_logger.setLevel(50)  # CRITICAL: silence jieba start-up logging
        jieba.initialize()

    final_text_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in text_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):
                # Pure ASCII segment (letters, digits, symbols). Separate a
                # multi-character segment from the previous token with a space
                # unless that token is already a separator.
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):
                # Every character is 3 bytes in UTF-8 (east-asian text):
                # convert the whole segment at once.
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:
                # Mixed segment: handle character by character.
                for c in seg:
                    if ord(c) < 256:
                        char_list.append(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_text_list.append(char_list)
    return final_text_list


# From the official code
def list_str_to_idx(
    text,  # list[str] | list[list[str]]
    vocab_char_map,  # dict[str, int], {char: idx}
    padding_value=-1,
):
    """Map token sequences to a padded ``torch.int32`` id tensor.

    Each sequence in ``text`` is looked up in ``vocab_char_map``; unknown
    tokens map to id 0. Sequences are right-padded with ``padding_value`` to
    the longest one (``batch_first=True``), so the result has shape
    ``(len(text), longest_len)``.
    """
    get_idx = vocab_char_map.get
    list_idx_tensors = [torch.tensor([get_idx(c, 0) for c in t], dtype=torch.int32) for t in text]
    return torch.nn.utils.rnn.pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)


def normalize_to_int16(audio):
    """Peak-normalise ``audio`` so its largest magnitude is 32767, as ``np.int16``.

    An all-zero signal is cast unscaled; an empty signal returns an empty
    int16 array instead of raising from ``np.max``.
    """
    audio = np.asarray(audio)
    if audio.size == 0:
        return audio.astype(np.int16)
    max_val = np.max(np.abs(audio))
    scaling_factor = 32767.0 / max_val if max_val > 0 else 1.0
    return (audio * float(scaling_factor)).astype(np.int16)


# Preprocess (A) and decode (C) run on ONNX Runtime's CPU provider; the
# transformer (B) runs through the rknnort (ztu_somemodelruntime_rknnlite2) session.
ort_session_A = onnxruntime.InferenceSession(onnx_model_A, providers=['CPUExecutionProvider'], provider_options=None)
in_name_A = ort_session_A.get_inputs()
out_name_A = ort_session_A.get_outputs()
in_name_A0 = in_name_A[0].name
in_name_A1 = in_name_A[1].name
in_name_A2 = in_name_A[2].name
out_name_A0 = out_name_A[0].name
out_name_A1 = out_name_A[1].name
out_name_A2 = out_name_A[2].name
out_name_A3 = out_name_A[3].name
out_name_A4 = out_name_A[4].name
out_name_A5 = out_name_A[5].name
out_name_A6 = out_name_A[6].name
out_name_A7 = out_name_A[7].name

ort_session_B = rknnort.InferenceSession(onnx_model_B)
in_name_B = ort_session_B.get_inputs()
out_name_B = ort_session_B.get_outputs()
in_name_B0 = in_name_B[0].name
in_name_B1 = in_name_B[1].name
in_name_B2 = in_name_B[2].name
in_name_B3 = in_name_B[3].name
in_name_B4 = in_name_B[4].name
in_name_B5 = in_name_B[5].name
in_name_B6 = in_name_B[6].name
in_name_B7 = in_name_B[7].name
out_name_B0 = out_name_B[0].name
out_name_B1 = out_name_B[1].name

ort_session_C = onnxruntime.InferenceSession(onnx_model_C, providers=['CPUExecutionProvider'], provider_options=None)
in_name_C = ort_session_C.get_inputs()
out_name_C = ort_session_C.get_outputs()
in_name_C0 = in_name_C[0].name
in_name_C1 = in_name_C[1].name
out_name_C0 = out_name_C[0].name

# Load the reference audio as mono at the model sample rate.
print(f"\nReference Audio: {reference_audio}")
audio = np.array(AudioSegment.from_file(reference_audio).set_channels(1).set_frame_rate(MODEL_SAMPLE_RATE).get_array_of_samples(), dtype=np.float32)
# Peak-normalise to the full int16 range; these int16 samples are what the
# preprocess model receives.
audio = normalize_to_int16(audio)
audio_len = len(audio)
audio = audio.reshape(1, 1, -1)  # presumably (batch, channels, samples) -- confirm against the export

# Pause punctuation marks consulted by the duration estimate.
zh_pause_punc = r"。,、;:?!"
# Estimate the natural output length in mel frames (official F5-TTS heuristic):
# UTF-8 byte length of each text plus 3 per pause punctuation mark.
# zh_pause_punc is a *set of characters*, so count each mark individually;
# passing it to re.findall as a pattern only matches the whole sequence verbatim.
ref_pause_count = sum(ch in zh_pause_punc for ch in ref_text)
gen_pause_count = sum(ch in zh_pause_punc for ch in gen_text)
ref_text_len = len(ref_text.encode('utf-8')) + 3 * ref_pause_count
gen_text_len = len(gen_text.encode('utf-8')) + 3 * gen_pause_count
ref_audio_len = audio_len // HOP_LENGTH + 1
# max(..., 1) avoids ZeroDivisionError for an empty reference transcript.
original_max_duration_val = ref_audio_len + int(ref_audio_len / max(ref_text_len, 1) * gen_text_len / SPEED)

fixed_max_duration_val = 1536  # The hardcoded duration passed to the preprocess model
# The model always generates fixed_max_duration_val frames, so the speech is
# stretched/compressed by this ratio and is corrected with a speed change below.
# NOTE(review): if F5_Decode drops the first ref_signal_len frames (its input
# suggests so -- confirm), the generated part spans (fixed - ref_audio_len)
# frames versus a natural (original - ref_audio_len), and this ratio should
# probably be computed from those differences instead.
speed_adjustment_ratio = fixed_max_duration_val / original_max_duration_val
print(f"Original estimated duration: {original_max_duration_val} frames")
print(f"Fixed duration: {fixed_max_duration_val} frames")
print(f"Speed adjustment ratio: {speed_adjustment_ratio:.3f}")

# Force max_duration to the fixed value for the model input
max_duration = np.array([fixed_max_duration_val], dtype=np.int64)
print(f"zt: hack: force max_duration to {fixed_max_duration_val}")

gen_text = convert_char_to_pinyin([ref_text + gen_text])
text_ids = list_str_to_idx(gen_text, vocab_char_map).numpy()
time_step = np.array([0], dtype=np.int32)

print("\n\nRun F5-TTS.")
total_start = time.time()  # covers preprocess + all NFE steps + decode
start_count = total_start
noise, rope_cos_q, rope_sin_q, rope_cos_k, rope_sin_k, cat_mel_text, cat_mel_text_drop, ref_signal_len = ort_session_A.run(
    [out_name_A0, out_name_A1, out_name_A2, out_name_A3, out_name_A4, out_name_A5, out_name_A6, out_name_A7],
    {
        in_name_A0: audio,
        in_name_A1: text_ids,
        in_name_A2: max_duration,
    })
end_count = time.time()
print(f"\nPrepare input data time cost: {end_count - start_count:.3f} seconds")

print("NFE_STEP: 0")
for i in tqdm(range(0, NFE_STEP - 1, FUSE_NFE)):
    # Each call advances the latent `noise`; the model also returns the next time_step.
    noise, time_step = ort_session_B.run(
        [out_name_B0, out_name_B1],
        {
            in_name_B0: noise,
            in_name_B1: rope_cos_q,
            in_name_B2: rope_sin_q,
            in_name_B3: rope_cos_k,
            in_name_B4: rope_sin_k,
            in_name_B5: cat_mel_text,
            in_name_B6: cat_mel_text_drop,
            in_name_B7: time_step,
        })
    print(f"NFE_STEP: {i + FUSE_NFE}")

start_count = time.time()
generated_signal = ort_session_C.run(
    [out_name_C0],
    {
        in_name_C0: noise,
        in_name_C1: ref_signal_len,
    })[0]
end_count = time.time()
print(f"Decode time cost: {end_count - start_count:.3f} seconds")
total_inference_time = end_count - total_start

# Adjust speed of the generated audio
print("\nAdjusting audio speed...")
# Flatten signal and ensure it's float32 before normalization
generated_signal_flat = generated_signal.flatten().astype(np.float32)
# Normalize the generated signal to int16 for pydub
generated_signal_int16 = normalize_to_int16(generated_signal_flat)

try:
    segment = AudioSegment(
        generated_signal_int16.tobytes(),
        frame_rate=MODEL_SAMPLE_RATE,
        sample_width=generated_signal_int16.dtype.itemsize,
        channels=1,
    )
    # pydub's speedup() works by dropping short chunks of audio. Its chunk
    # arithmetic assumes playback_speed > 1.0 (the amount to remove goes
    # negative otherwise), so it can only speed audio up, never slow it down.
    if speed_adjustment_ratio > 1.0:
        print(f"Applying speed change with factor: {speed_adjustment_ratio:.3f}")
        adjusted_segment = segment.speedup(playback_speed=speed_adjustment_ratio)
    elif speed_adjustment_ratio < 1.0:
        print(f"Speed factor {speed_adjustment_ratio:.3f} < 1.0 cannot be applied with pydub speedup; keeping unadjusted audio.")
        adjusted_segment = segment
    else:
        print("No speed adjustment needed.")
        adjusted_segment = segment

    # Save the adjusted audio using pydub's export
    adjusted_segment.export(generated_audio, format="wav")
    print(f"\nSpeed-adjusted audio saved to: {generated_audio}")
except Exception as e:
    # Best-effort fallback at the script's top level: report and still save output.
    print(f"Error during audio speed adjustment or saving: {e}")
    print("Saving original audio instead.")
    sf.write(generated_audio, generated_signal_flat, MODEL_SAMPLE_RATE, format='WAVEX')

print(f"\nAudio generation and processing is complete.\n\nONNXRuntime Time Cost in Seconds:\n{total_inference_time:.3f}")