# F5-TTS-RKNN2 / F5-TTS-ONNX-Inference-rknn2.py
# import ztu_ort_enable_debug
import re
import time
import jieba
import torch
import ztu_somemodelruntime_rknnlite2 as rknnort
import onnxruntime
import soundfile as sf
import numpy as np
from pydub import AudioSegment
from pypinyin import lazy_pinyin, Style
from tqdm import tqdm
vocab_path = "vocab.txt" # The F5-TTS model vocab download path. URL: https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base
onnx_model_A = "F5_Preprocess.onnx" # The exported onnx model path.
onnx_model_B = "F5_Transformer_opset19.onnx" # The exported onnx model path.
onnx_model_C = "F5_Decode.onnx" # The exported onnx model path.
generated_audio = "./generated_audio-rknn2.wav"
test_in_english = False
if test_in_english:
    reference_audio = "basic_ref_en.wav"
    ref_text = "Some call me nature, others call me mother nature."
    gen_text = "Some call me Dake, others call me QQ."
else:
    reference_audio = "basic_ref_zh.wav"        # The reference audio path.
    ref_text = "对,这就是我,万人敬仰的太乙真人。"    # The ASR transcript of the reference audio.
    gen_text = "哇,这个辣鸡芯片居然也能运行模型。"    # The target text to synthesize.
# RANDOM_SEED = 9527 # Set seed to reproduce the generated audio
NFE_STEP = 32                # Number of F5-TTS denoising steps; time_step runs 0..NFE_STEP-1.
FUSE_NFE = 1                 # Denoising steps fused per transformer call; must match the exported model.
SPEED = 1.0                  # Talking-speed factor. Only effective when exported with dynamic_axes=True.
MODEL_SAMPLE_RATE = 24000    # Fixed by the model. Do not modify.
HOP_LENGTH = 256             # Mel hop length; affects the generated audio length and speech speed.
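# For reference: one mel frame spans HOP_LENGTH / MODEL_SAMPLE_RATE seconds,
# i.e. 256 / 24000 ≈ 10.67 ms, or 93.75 frames per second. Informational only;
# nothing below depends on it.
FRAMES_PER_SECOND = MODEL_SAMPLE_RATE / HOP_LENGTH  # 93.75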
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab_char_map = {}
    for i, char in enumerate(f):
        vocab_char_map[char[:-1]] = i  # char[:-1] strips the trailing newline
vocab_size = len(vocab_char_map)
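# Quick sanity check on the vocab (the exact size depends on the downloaded vocab.txt).
print(f"Loaded {vocab_size} tokens from {vocab_path}")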
# From the official code
def convert_char_to_pinyin(text_list, polyphone=True):
    if not jieba.dt.initialized:
        jieba.default_logger.setLevel(50)  # CRITICAL
        jieba.initialize()
    final_text_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in text_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_text_list.append(char_list)
    return final_text_list
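# Illustrative example (exact output depends on the installed jieba/pypinyin
# versions): convert_char_to_pinyin(["你好, world"]) yields roughly
#   [[" ", "ni3", " ", "hao3", ",", " ", "w", "o", "r", "l", "d"]]
# i.e. Chinese characters become space-separated tone3-style pinyin while
# ASCII text is kept character by character.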
# From the official code
def list_str_to_idx(
    text,            # list[str] | list[list[str]]
    vocab_char_map,  # dict[str, int]  {char: idx}
    padding_value=-1,
):
    get_idx = vocab_char_map.get
    # Tokens missing from the vocab fall back to index 0.
    list_idx_tensors = [torch.tensor([get_idx(c, 0) for c in t], dtype=torch.int32) for t in text]
    text = torch.nn.utils.rnn.pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return text
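# Illustrative usage (token ids are hypothetical): with
#   vocab_char_map = {" ": 1, "ni3": 10, "hao3": 11}
# list_str_to_idx([[" ", "ni3", " ", "hao3"]], vocab_char_map) returns
# tensor([[1, 10, 1, 11]], dtype=torch.int32).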
def normalize_to_int16(audio):
    # Peak-normalize to the full int16 range; silent input passes through unscaled.
    max_val = np.max(np.abs(audio))
    scaling_factor = 32767.0 / max_val if max_val > 0 else 1.0
    return (audio * float(scaling_factor)).astype(np.int16)
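# Worked example: normalize_to_int16(np.array([0.5, -1.0], dtype=np.float32))
# scales by 32767 / 1.0 and returns array([16383, -32767], dtype=int16).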
ort_session_A = onnxruntime.InferenceSession(onnx_model_A, providers=['CPUExecutionProvider'], provider_options=None)
model_type = ort_session_A._inputs_meta[0].type  # Input dtype of the preprocess model (currently unused in this script).
in_name_A = ort_session_A.get_inputs()
out_name_A = ort_session_A.get_outputs()
in_name_A0 = in_name_A[0].name
in_name_A1 = in_name_A[1].name
in_name_A2 = in_name_A[2].name
out_name_A0 = out_name_A[0].name
out_name_A1 = out_name_A[1].name
out_name_A2 = out_name_A[2].name
out_name_A3 = out_name_A[3].name
out_name_A4 = out_name_A[4].name
out_name_A5 = out_name_A[5].name
out_name_A6 = out_name_A[6].name
out_name_A7 = out_name_A[7].name
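# ztu_somemodelruntime_rknnlite2 presumably mirrors the onnxruntime
# InferenceSession API on top of RKNN-Lite2, so the transformer below runs on
# the Rockchip NPU while the pre- and post-processing stages stay on the CPU.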
ort_session_B = rknnort.InferenceSession(onnx_model_B)
in_name_B = ort_session_B.get_inputs()
out_name_B = ort_session_B.get_outputs()
in_name_B0 = in_name_B[0].name
in_name_B1 = in_name_B[1].name
in_name_B2 = in_name_B[2].name
in_name_B3 = in_name_B[3].name
in_name_B4 = in_name_B[4].name
in_name_B5 = in_name_B[5].name
in_name_B6 = in_name_B[6].name
in_name_B7 = in_name_B[7].name
out_name_B0 = out_name_B[0].name
out_name_B1 = out_name_B[1].name
ort_session_C = onnxruntime.InferenceSession(onnx_model_C, providers=['CPUExecutionProvider'], provider_options=None)
in_name_C = ort_session_C.get_inputs()
out_name_C = ort_session_C.get_outputs()
in_name_C0 = in_name_C[0].name
in_name_C1 = in_name_C[1].name
out_name_C0 = out_name_C[0].name
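# Pipeline summary: model A (preprocess) builds the initial noise, rotary
# embeddings, and mel/text conditioning; model B (transformer, on the NPU)
# iteratively denoises; model C (decode) turns the result back into a waveform.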
# Load the input audio
print(f"\nReference Audio: {reference_audio}")
audio = np.array(AudioSegment.from_file(reference_audio).set_channels(1).set_frame_rate(MODEL_SAMPLE_RATE).get_array_of_samples(), dtype=np.float32)
audio = normalize_to_int16(audio)  # The exported preprocess model expects int16 PCM; the final output is normalized again before saving.
audio_len = len(audio)
audio = audio.reshape(1, 1, -1)  # [batch, channel, samples]
zh_pause_punc = r"。,、;:?!"
ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
ref_audio_len = audio_len // HOP_LENGTH + 1
original_max_duration_val = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / SPEED)
fixed_max_duration_val = 1536 # Hardcoded duration; the RKNN transformer graph is compiled with a static shape.
speed_adjustment_ratio = fixed_max_duration_val / original_max_duration_val
print(f"Original estimated duration: {original_max_duration_val} frames")
print(f"Fixed duration: {fixed_max_duration_val} frames")
print(f"Speed adjustment ratio: {speed_adjustment_ratio:.3f}")
# Force max_duration to the fixed value for the model input
max_duration = np.array([fixed_max_duration_val], dtype=np.int64)
print(f"zt: hack: force max_duration to {fixed_max_duration_val}")
combined_text = convert_char_to_pinyin([ref_text + gen_text])
text_ids = list_str_to_idx(combined_text, vocab_char_map).numpy()
time_step = np.array([0], dtype=np.int32)
print("\n\nRun F5-TTS.")
start_total = time.time()  # Wall-clock start for the whole pipeline.
start_count = start_total
noise, rope_cos_q, rope_sin_q, rope_cos_k, rope_sin_k, cat_mel_text, cat_mel_text_drop, ref_signal_len = ort_session_A.run(
    [out_name_A0, out_name_A1, out_name_A2, out_name_A3, out_name_A4, out_name_A5, out_name_A6, out_name_A7],
    {
        in_name_A0: audio,
        in_name_A1: text_ids,
        in_name_A2: max_duration
    })
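# An interpretation of the preprocess outputs (verify against the export):
# `noise` is the initial latent to denoise, the rope_* tensors are rotary
# position embeddings for queries/keys, `cat_mel_text` / `cat_mel_text_drop`
# are the conditioned and text-dropped inputs for classifier-free guidance,
# and `ref_signal_len` marks where the reference audio ends in the mel sequence.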
end_count = time.time()
print(f"\nPrepare input data time cost: {end_count - start_count:.3f} seconds")
print("NFE_STEP: 0")
for i in tqdm(range(0, NFE_STEP - 1, FUSE_NFE)):
    noise, time_step = ort_session_B.run(
        [out_name_B0, out_name_B1],
        {
            in_name_B0: noise,
            in_name_B1: rope_cos_q,
            in_name_B2: rope_sin_q,
            in_name_B3: rope_cos_k,
            in_name_B4: rope_sin_k,
            in_name_B5: cat_mel_text,
            in_name_B6: cat_mel_text_drop,
            in_name_B7: time_step
        })
    print(f"NFE_STEP: {i + FUSE_NFE}")
start_count = time.time()
generated_signal = ort_session_C.run(
    [out_name_C0],
    {
        in_name_C0: noise,
        in_name_C1: ref_signal_len
    })[0]
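# The decode model converts the denoised sequence into a waveform; passing
# ref_signal_len presumably lets it trim the reference-audio portion so only
# the newly generated speech remains.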
end_count = time.time()
print(f"Decode time cost: {end_count - start_count:.3f} seconds")
# Adjust speed of the generated audio
print("\nAdjusting audio speed...")
# Flatten signal and ensure it's float32 before normalization
generated_signal_flat = generated_signal.flatten().astype(np.float32)
# Normalize the generated signal to int16 for pydub
generated_signal_int16 = normalize_to_int16(generated_signal_flat)
# Create an AudioSegment from the int16 numpy array
try:
    segment = AudioSegment(
        generated_signal_int16.tobytes(),
        frame_rate=MODEL_SAMPLE_RATE,
        sample_width=generated_signal_int16.dtype.itemsize,
        channels=1
    )
    # Apply speed adjustment.
    # Note: a factor > 1.0 means faster playback; pydub's speedup() is not
    # designed for factors below 1.0, so slowing down may fail and fall
    # through to the except branch.
    if speed_adjustment_ratio != 1.0:
        print(f"Applying speed change with factor: {speed_adjustment_ratio:.3f}")
        adjusted_segment = segment.speedup(playback_speed=speed_adjustment_ratio)
    else:
        print("No speed adjustment needed.")
        adjusted_segment = segment
    # Save the adjusted audio using pydub's export.
    adjusted_segment.export(generated_audio, format="wav")
    print(f"\nSpeed-adjusted audio saved to: {generated_audio}")
except Exception as e:
    print(f"Error during audio speed adjustment or saving: {e}")
    print("Saving the original audio instead.")
    # Fall back to saving the unadjusted audio if adjustment fails.
    sf.write(generated_audio, generated_signal_flat, MODEL_SAMPLE_RATE, format='WAVEX')
print(f"\nAudio generation and processing is complete.\n\nONNXRuntime Time Cost in Seconds:\n{end_count - start_count:.3f}")