| | |
| | import re |
| | import time |
| | import jieba |
| | import torch |
| | import ztu_somemodelruntime_rknnlite2 as rknnort |
| | import onnxruntime |
| | import soundfile as sf |
| | import numpy as np |
| | from pydub import AudioSegment |
| | from pypinyin import lazy_pinyin, Style |
| | from tqdm import tqdm |
| |
|
# --- File paths and model artifacts --------------------------------------
vocab_path = "vocab.txt"                        # token -> id table, one token per line
onnx_model_A = "F5_Preprocess.onnx"             # preprocessing graph
onnx_model_B = "F5_Transformer_opset19.onnx"    # transformer graph (RKNN runtime)
onnx_model_C = "F5_Decode.onnx"                 # decoder / vocoder graph
generated_audio = "./generated_audio-rknn2.wav" # output WAV path

# Language toggle: selects the matching reference clip and prompt texts.
test_in_english = False

if test_in_english:
    reference_audio, ref_text, gen_text = (
        "basic_ref_en.wav",
        "Some call me nature, others call me mother nature.",
        "Some call me Dake, others call me QQ.",
    )
else:
    reference_audio, ref_text, gen_text = (
        "basic_ref_zh.wav",
        "对,这就是我,万人敬仰的太乙真人。",
        "哇,这个辣鸡芯片居然也能运行模型。",
    )

# --- Inference hyper-parameters ------------------------------------------
NFE_STEP = 32              # number of denoising (NFE) steps
FUSE_NFE = 1               # steps advanced per transformer call
SPEED = 1.0                # nominal speech-rate factor for duration estimation
MODEL_SAMPLE_RATE = 24000  # sample rate the model operates at (Hz)
HOP_LENGTH = 256           # samples per mel frame
| |
|
| |
|
| |
|
# Build the token -> id map from the vocab file (one token per line; line
# number is the id).
# FIX: use rstrip("\n") instead of char[:-1] — slicing off the last
# character silently truncates the final token when the file has no
# trailing newline, mapping a corrupted token to the last id.
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab_char_map = {}
    for i, char in enumerate(f):
        vocab_char_map[char.rstrip("\n")] = i
vocab_size = len(vocab_char_map)  # NOTE(review): not read elsewhere in this file
| |
|
| |
|
| | |
def convert_char_to_pinyin(text_list, polyphone=True):
    """Convert mixed Chinese/ASCII strings into token lists for the vocab.

    ASCII runs are kept as individual characters; CJK characters are
    replaced by their TONE3 pinyin (with tone sandhi), each preceded by
    a space separator token.

    Parameters
    ----------
    text_list : list[str]
        Input strings to tokenize.
    polyphone : bool
        When True, convert whole jieba words at once so polyphonic
        characters get word-level context.

    Returns
    -------
    list[list[str]]
        One token list per input string.
    """
    # Initialise jieba lazily, silencing its startup logging (level 50 = CRITICAL).
    if not jieba.dt.initialized:
        jieba.default_logger.setLevel(50)
        jieba.initialize()

    # Normalise a few full-width punctuation marks to ASCII equivalents.
    punct_map = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )

    def is_chinese(ch):
        # NOTE(review): U+3100..U+9FFF also spans bopomofo and other CJK
        # blocks, not just Han ideographs.
        return (
            "\u3100" <= ch <= "\u9fff"
        )

    results = []
    for raw in text_list:
        tokens = []
        normalised = raw.translate(punct_map)
        for word in jieba.cut(normalised):
            nbytes = len(bytes(word, "UTF-8"))
            if nbytes == len(word):
                # Pure ASCII word: insert a space before multi-character
                # words unless a separator-like token precedes it.
                if tokens and nbytes > 1 and tokens[-1] not in " :'\"":
                    tokens.append(" ")
                tokens.extend(word)
            elif polyphone and nbytes == 3 * len(word):
                # Word made entirely of 3-byte (CJK) characters: convert
                # the whole word for polyphone disambiguation.
                readings = lazy_pinyin(word, style=Style.TONE3, tone_sandhi=True)
                for ch, reading in zip(word, readings):
                    if is_chinese(ch):
                        tokens.append(" ")
                    tokens.append(reading)
            else:
                # Mixed word: fall back to per-character handling.
                for ch in word:
                    if ord(ch) < 256:
                        tokens.extend(ch)
                    elif is_chinese(ch):
                        tokens.append(" ")
                        tokens.extend(lazy_pinyin(ch, style=Style.TONE3, tone_sandhi=True))
                    else:
                        tokens.append(ch)
        results.append(tokens)
    return results
| |
|
| |
|
| | |
def list_str_to_idx(
    text,
    vocab_char_map,
    padding_value=-1
):
    """Map token lists to a right-padded int32 index tensor.

    Tokens missing from ``vocab_char_map`` fall back to index 0; rows are
    padded with ``padding_value`` to the longest sequence in the batch.

    Parameters
    ----------
    text : list[list[str]]
        Batch of token sequences.
    vocab_char_map : dict[str, int]
        Token-to-index lookup table.
    padding_value : int
        Fill value for shorter rows.

    Returns
    -------
    torch.Tensor of shape (batch, max_len), dtype int32.
    """
    lookup = vocab_char_map.get
    rows = [
        torch.tensor([lookup(token, 0) for token in seq], dtype=torch.int32)
        for seq in text
    ]
    return torch.nn.utils.rnn.pad_sequence(
        rows, padding_value=padding_value, batch_first=True
    )
| |
|
| |
|
def normalize_to_int16(audio):
    """Peak-normalize a float waveform to the full int16 range.

    Scales the input so the loudest sample maps to +/-32767. All-zero
    input is passed through unscaled (avoids division by zero), and
    empty input returns an empty int16 array instead of raising — the
    original `np.max` on an empty array raises ValueError.

    Parameters
    ----------
    audio : np.ndarray
        Float waveform, any shape.

    Returns
    -------
    np.ndarray
        Same shape as the input, dtype int16.
    """
    if audio.size == 0:
        # Robustness: np.max on an empty array raises; return empty int16.
        return audio.astype(np.int16)
    max_val = np.max(np.abs(audio))
    # Guard the silent (all-zero) case against division by zero.
    scaling_factor = 32767.0 / max_val if max_val > 0 else 1.0
    return (audio * float(scaling_factor)).astype(np.int16)
| |
|
# Preprocessing graph: runs on CPU via onnxruntime.
ort_session_A = onnxruntime.InferenceSession(
    onnx_model_A, providers=['CPUExecutionProvider'], provider_options=None
)
model_type = ort_session_A._inputs_meta[0].type  # input dtype string (informational)
in_name_A = ort_session_A.get_inputs()
out_name_A = ort_session_A.get_outputs()
# Cache the tensor names used when calling run().
in_name_A0, in_name_A1, in_name_A2 = [m.name for m in in_name_A[:3]]
(out_name_A0, out_name_A1, out_name_A2, out_name_A3,
 out_name_A4, out_name_A5, out_name_A6, out_name_A7) = [m.name for m in out_name_A[:8]]
| |
|
| |
|
# Transformer graph: runs through the RKNN-Lite runtime shim.
ort_session_B = rknnort.InferenceSession(onnx_model_B)

in_name_B = ort_session_B.get_inputs()
out_name_B = ort_session_B.get_outputs()
# Cache the tensor names used in the denoising loop.
(in_name_B0, in_name_B1, in_name_B2, in_name_B3,
 in_name_B4, in_name_B5, in_name_B6, in_name_B7) = [m.name for m in in_name_B[:8]]
out_name_B0, out_name_B1 = [m.name for m in out_name_B[:2]]
| |
|
# Decoder graph: runs on CPU via onnxruntime.
ort_session_C = onnxruntime.InferenceSession(
    onnx_model_C, providers=['CPUExecutionProvider'], provider_options=None
)
in_name_C = ort_session_C.get_inputs()
out_name_C = ort_session_C.get_outputs()
in_name_C0, in_name_C1 = [m.name for m in in_name_C[:2]]
out_name_C0 = out_name_C[0].name
| |
|
| | |
print(f"\nReference Audio: {reference_audio}")
# Load the reference clip as mono at the model sample rate, peak-normalize
# to int16, and shape it to (1, 1, samples) for the preprocessing graph.
audio = np.array(AudioSegment.from_file(reference_audio).set_channels(1).set_frame_rate(MODEL_SAMPLE_RATE).get_array_of_samples(), dtype=np.float32)
audio = normalize_to_int16(audio)
audio_len = len(audio)
audio = audio.reshape(1, 1, -1)

# Weighted text lengths: UTF-8 byte count plus an extra 3 per Chinese
# pause punctuation mark (pauses lengthen speech).
# FIX: the pattern must be a character class — the bare string
# r"。,、;:?!" only matches that literal 7-character sequence, so the
# pause weighting never triggered.
zh_pause_punc = r"。,、;:?!"
zh_pause_pattern = re.compile(f"[{zh_pause_punc}]")
ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(zh_pause_pattern.findall(ref_text))
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(zh_pause_pattern.findall(gen_text))
ref_audio_len = audio_len // HOP_LENGTH + 1  # reference length in mel frames
| |
|
| |
|
# Estimate the total output duration (in mel frames) from the ratio of
# generated-text length to reference-text length, then force a fixed
# 1536-frame duration (the RKNN transformer needs static shapes); the
# resulting tempo error is undone later by time-stretching the audio.
original_max_duration_val = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / SPEED)
fixed_max_duration_val = 1536
# > 1.0 means the fixed duration stretched the speech, so playback must be
# sped up by this factor to restore the estimated natural tempo.
speed_adjustment_ratio = fixed_max_duration_val / original_max_duration_val

print(f"Original estimated duration: {original_max_duration_val} frames")
print(f"Fixed duration: {fixed_max_duration_val} frames")
print(f"Speed adjustment ratio: {speed_adjustment_ratio:.3f}")

max_duration = np.array([fixed_max_duration_val], dtype=np.int64)

print(f"zt: hack: force max_duration to {fixed_max_duration_val}")
# The model is conditioned on ref_text + gen_text as one token sequence.
gen_text = convert_char_to_pinyin([ref_text + gen_text])
text_ids = list_str_to_idx(gen_text, vocab_char_map).numpy()
# Step counter fed to (and updated by) the transformer graph each iteration.
time_step = np.array([0], dtype=np.int32)
| |
|
| |
|
device_type = None  # NOTE(review): never read afterwards — candidate for removal

print("\n\nRun F5-TTS.")
start_count = time.time()
# Preprocessing graph: produces the initial noise, rotary position
# embeddings, mel+text conditioning with and without text (the names
# suggest classifier-free guidance — confirm), and the reference signal
# length used later by the decoder.
noise, rope_cos_q, rope_sin_q, rope_cos_k, rope_sin_k, cat_mel_text, cat_mel_text_drop, ref_signal_len = ort_session_A.run(
    [out_name_A0, out_name_A1, out_name_A2, out_name_A3, out_name_A4, out_name_A5, out_name_A6, out_name_A7],
    {
        in_name_A0: audio,
        in_name_A1: text_ids,
        in_name_A2: max_duration
    })
end_count = time.time()
print(f"\nPrepare input data time cost: {end_count - start_count:.3f} seconds")
| |
|
# Iterative denoising on the transformer graph: each call advances `noise`
# by FUSE_NFE step(s) and returns the updated time_step.
print("NFE_STEP: 0")
# NOTE(review): range(0, NFE_STEP - 1, FUSE_NFE) performs NFE_STEP - 1
# iterations when FUSE_NFE == 1 — confirm the graph folds the final step
# internally, otherwise this looks like an off-by-one.
for i in tqdm(range(0, NFE_STEP - 1, FUSE_NFE)):
    noise, time_step = ort_session_B.run(
        [out_name_B0, out_name_B1],
        {
            in_name_B0: noise,
            in_name_B1: rope_cos_q,
            in_name_B2: rope_sin_q,
            in_name_B3: rope_cos_k,
            in_name_B4: rope_sin_k,
            in_name_B5: cat_mel_text,
            in_name_B6: cat_mel_text_drop,
            in_name_B7: time_step
        })
    print(f"NFE_STEP: {i + FUSE_NFE}")
| |
|
start_count = time.time()
# Decoder graph: converts the denoised representation back to a waveform;
# ref_signal_len tells it where the reference-prompt portion ends.
generated_signal = ort_session_C.run(
    [out_name_C0],
    {
        in_name_C0: noise,
        in_name_C1: ref_signal_len
    })[0]
end_count = time.time()
print(f"Decode time cost: {end_count - start_count:.3f} seconds")
| |
|
| | |
print("\nAdjusting audio speed...")

# Flatten the decoder output to a 1-D float waveform.
generated_signal_flat = generated_signal.flatten().astype(np.float32)

# Peak-normalize to int16 PCM for pydub.
generated_signal_int16 = normalize_to_int16(generated_signal_flat)

try:
    # Wrap the raw PCM samples in a pydub segment (mono, model sample rate).
    segment = AudioSegment(
        generated_signal_int16.tobytes(),
        frame_rate=MODEL_SAMPLE_RATE,
        sample_width=generated_signal_int16.dtype.itemsize,
        channels=1
    )

    if speed_adjustment_ratio != 1.0:
        print(f"Applying speed change with factor: {speed_adjustment_ratio:.3f}")
        # NOTE(review): pydub's speedup() is documented for factors > 1.0
        # only; if the ratio is < 1.0 (estimated duration longer than the
        # fixed 1536 frames) this may fail and fall into the except branch.
        adjusted_segment = segment.speedup(playback_speed=speed_adjustment_ratio)
    else:
        print("No speed adjustment needed.")
        adjusted_segment = segment

    adjusted_segment.export(generated_audio, format="wav")
    print(f"\nSpeed-adjusted audio saved to: {generated_audio}")

except Exception as e:
    # Best-effort fallback: on any pydub failure, save the unadjusted audio.
    print(f"Error during audio speed adjustment or saving: {e}")
    print("Saving original audio instead.")
    sf.write(generated_audio, generated_signal_flat, MODEL_SAMPLE_RATE, format='WAVEX')

# NOTE(review): start_count/end_count were last set around the decode call,
# so this figure is the decode time only, not the whole pipeline.
print(f"\nAudio generation and processing is complete.\n\nONNXRuntime Time Cost in Seconds:\n{end_count - start_count:.3f}")
| |
|