# OpenVoice / OpenVoice-RKNN2 / test_rknn.py
# niobures's picture
# OpenVoice & OpenVoiceV2
# e0b3bf7 verified
from typing import Callable
import numpy as np
import onnxruntime as ort
import os
from rknnlite.api import RKNNLite
import json
import os
import time
class HParams:
    """Attribute-style container for a (possibly nested) configuration mapping.

    Nested mappings are converted recursively into ``HParams`` instances, so a
    config such as ``{"data": {"hop_length": 256}}`` is readable as
    ``hps.data.hop_length``. Entries can also be read and written with
    ``hps["key"]`` and tested with ``"key" in hps``.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (rather than ``type(v) == dict``) so that dict
            # subclasses such as OrderedDict are converted as well.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        """Return a view of the stored parameter names."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of the stored ``(name, value)`` pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the stored values."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()

    @staticmethod
    def load_from_file(file_path:str):
        """Load an ``HParams`` tree from a UTF-8 encoded JSON file.

        Args:
            file_path (str): Path of the JSON configuration file. Its top-level
                value must be a JSON object.

        Returns:
            HParams: The parsed configuration.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Can not found the configuration file \"{file_path}\"")
        with open(file_path, "r", encoding="utf-8") as f:
            hps = json.load(f)
        return HParams(**hps)
class BaseClassForOnnxInfer():
    """Shared helpers for classes that wrap onnxruntime inference sessions."""

    @staticmethod
    def create_onnx_infer(infer_factor:Callable, onnx_model_path:str, providers:list, session_options:ort.SessionOptions = None, onnx_params:dict = None):
        """Build an inference callable backed by a new onnxruntime session.

        Args:
            infer_factor (Callable): Receives the created ``InferenceSession`` and
                returns the callable handed back to the caller.
            onnx_model_path (str): Path of the ``.onnx`` model file.
            providers (list): Execution providers passed to the session.
            session_options (onnxruntime.SessionOptions): Optional session options;
                when not given, the defaults from ``get_def_onnx_session_options`` are used.
            onnx_params (dict): Extra keyword arguments for ``InferenceSession``.

        Returns:
            The callable produced by ``infer_factor``, with the session also
            attached to it as an attribute.

        Raises:
            FileNotFoundError: If ``onnx_model_path`` does not exist.
        """
        model_missing = not os.path.exists(onnx_model_path)
        if model_missing:
            raise FileNotFoundError(f"Can not found the onnx model file \"{onnx_model_path}\"")
        extra_kwargs = onnx_params if onnx_params else {}
        effective_options = BaseClassForOnnxInfer.adjust_onnx_session_options(session_options)
        onnx_session = ort.InferenceSession(
            onnx_model_path,
            sess_options=effective_options,
            providers=providers,
            **extra_kwargs,
        )
        infer_callable = infer_factor(onnx_session)
        # Name-mangled to ``_BaseClassForOnnxInfer__session``; keeps the session
        # reachable from the returned callable.
        infer_callable.__session = onnx_session
        return infer_callable

    @staticmethod
    def get_def_onnx_session_options():
        """Return new session options with every graph optimization enabled."""
        default_options = ort.SessionOptions()
        default_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        return default_options

    @staticmethod
    def adjust_onnx_session_options(session_options:ort.SessionOptions = None):
        """Return ``session_options`` if truthy, otherwise the default options."""
        if session_options:
            return session_options
        return BaseClassForOnnxInfer.get_def_onnx_session_options()
class OpenVoiceToneClone_ONNXRKNN(BaseClassForOnnxInfer):
    """OpenVoice tone-color cloner running on onnxruntime and RKNNLite.

    The tone-color extractor (``tone_color_extract_model.onnx``) runs on
    onnxruntime. The tone-clone model (``tone_clone_model.rknn``) runs through
    RKNNLite and is fed a spectrogram padded or truncated to a fixed
    ``target_length`` frames.
    """

    # Fallback order used when the requested onnxruntime provider is unavailable.
    PreferredProviders = ['CPUExecutionProvider']

    # Periodic Hann windows cached per "<dtype>-<win_size>" key; shared by all instances.
    hann_window = {}

    def __init__(self, model_path, execution_provider:str = None, verbose:bool = False, onnx_session_options:ort.SessionOptions = None, onnx_params:dict = None, target_length:int = 1024):
        '''
        Create the instance of the tone cloner

        Args:
            model_path (str): The folder containing configuration.json,
                tone_color_extract_model.onnx and tone_clone_model.rknn
            execution_provider (str): The onnxruntime provider for the tone color
                extractor, such as CPUExecutionProvider or CUDAExecutionProvider,
                or the short form CPU, CUDA. If it is None or not available, the
                first available entry of PreferredProviders is used, falling back
                to CPUExecutionProvider
            verbose (bool): Set True to print progress and shape information
            onnx_session_options (onnxruntime.SessionOptions): The custom options for the onnx session
            onnx_params (dict): Other keyword arguments for onnxruntime.InferenceSession
            target_length (int): Number of spectrogram frames fed to the RKNN model;
                longer inputs are truncated and shorter ones zero-padded. Defaults to 1024

        Raises:
            FileNotFoundError: If the configuration or the ONNX model file is missing
            RuntimeError: If the RKNN model cannot be loaded or its runtime cannot be initialized
        '''
        self.__verbose = verbose
        self.__target_length = target_length
        if verbose:
            print("Loading the configuration...")
        config_path = os.path.join(model_path, "configuration.json")
        self.__hparams = HParams.load_from_file(config_path)

        # Accept short provider names such as "CPU" or "CUDA".
        if (execution_provider is not None) and (not execution_provider.endswith("ExecutionProvider")):
            execution_provider = f"{execution_provider}ExecutionProvider"
        available_providers = ort.get_available_providers()
        # Honor the requested provider when onnxruntime offers it; otherwise fall
        # back to the first available preferred provider, then to CPU.
        if execution_provider in available_providers:
            chosen_provider = execution_provider
        else:
            chosen_provider = next(
                (provider for provider in self.PreferredProviders if provider in available_providers),
                'CPUExecutionProvider',
            )
        self.__execution_providers = [chosen_provider]

        if verbose:
            print("Creating onnx session for tone color extractor...")

        def se_infer_factor(session):
            # Run with keyword inputs and return only the first model output.
            return lambda **kwargs: session.run(None, kwargs)[0]

        self.__se_infer = self.create_onnx_infer(se_infer_factor, os.path.join(model_path, "tone_color_extract_model.onnx"), self.__execution_providers, onnx_session_options, onnx_params)

        if verbose:
            print("Creating RKNNLite session for tone clone ...")
        # Initialize RKNNLite and load the RKNN model.
        self.__tc_rknn = RKNNLite(verbose=verbose)
        ret = self.__tc_rknn.load_rknn(os.path.join(model_path, "tone_clone_model.rknn"))
        if ret != 0:
            raise RuntimeError("Failed to load RKNN model")
        # Initialize the RKNN runtime.
        ret = self.__tc_rknn.init_runtime()
        if ret != 0:
            raise RuntimeError("Failed to init RKNN runtime")

    def __del__(self):
        """Release the RKNNLite resources, if they were created."""
        if hasattr(self, '_OpenVoiceToneClone_ONNXRKNN__tc_rknn'):
            self.__tc_rknn.release()

    def __spectrogram_numpy(self, y, n_fft, sampling_rate, hop_size, win_size, onesided=True):
        '''
        Compute a magnitude spectrogram of a 1-D signal with NumPy.

        The signal is reflect-padded by (n_fft - hop_size) / 2 samples on each
        side, then cut into frames of n_fft samples every hop_size samples. Each
        frame is multiplied by a periodic Hann window of win_size samples,
        zero-padded and centred to n_fft samples when win_size < n_fft (the same
        convention torch.stft uses), and transformed with an FFT.

        Args:
            y (numpy.ndarray): The 1-D signal
            n_fft (int): FFT size, which is also the frame length
            sampling_rate (int): Unused; kept for signature compatibility
            hop_size (int): Distance between frame starts, in samples
            win_size (int): Hann window length; must not exceed n_fft
            onesided (bool): True for n_fft // 2 + 1 bins (rfft), False for n_fft bins (fft)

        Returns:
            numpy.ndarray: float32 array of shape (1, frames, bins) holding
            sqrt(real^2 + imag^2 + 1e-6)

        Raises:
            ValueError: If win_size > n_fft or the signal is too short for one frame
        '''
        if self.__verbose:
            if np.min(y) < -1.1:
                print("min value is ", np.min(y))
            if np.max(y) > 1.1:
                print("max value is ", np.max(y))
        if win_size > n_fft:
            raise ValueError(f"win_size ({win_size}) must not be greater than n_fft ({n_fft})")

        # Reflect padding.
        y = np.pad(
            y,
            int((n_fft - hop_size) / 2),
            mode="reflect",
        )

        # Periodic Hann window: np.hanning is symmetric, so build win_size + 1
        # points and drop the last one.
        cache = OpenVoiceToneClone_ONNXRKNN.hann_window
        win_key = f"{str(y.dtype)}-{win_size}"
        if win_key not in cache:
            cache[win_key] = np.hanning(win_size + 1)[:-1].astype(y.dtype)
        window = cache[win_key]
        if win_size < n_fft:
            left = (n_fft - win_size) // 2
            window = np.pad(window, (left, n_fft - win_size - left))

        # Short-time Fourier transform over all frames at once.
        y_len = y.shape[0]
        count = (y_len - n_fft) // hop_size + 1
        if count < 1:
            raise ValueError(
                f"Audio is too short for a spectrogram: {y_len} samples after padding, "
                f"at least {n_fft} are required"
            )
        frame_index = (np.arange(count) * hop_size)[:, None] + np.arange(n_fft)[None, :]
        frames = y[frame_index] * window
        stft = np.fft.rfft(frames, axis=-1) if onesided else np.fft.fft(frames, axis=-1)

        # Magnitude from the real and imaginary parts (computed in float64).
        real = stft.real.astype(np.float64)
        imag = stft.imag.astype(np.float64)
        spec = np.sqrt(real * real + imag * imag + 1e-6)
        return np.array([spec], dtype=np.float32)

    def extract_tone_color(self, audio:np.ndarray):
        '''
        Extract the tone color from an audio

        Args:
            audio (numpy.ndarray): The audio samples, mono or (samples, channels),
                at the model's sample rate (see resample)

        Returns:
            numpy.ndarray: The tone color vector with shape (1, 256, 1)
        '''
        hps = self.__hparams
        y = self.to_mono(audio.astype(np.float32))
        spec = self.__spectrogram_numpy(y, hps.data.filter_length,
            hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
        )
        if self.__verbose:
            print("spec shape", spec.shape)
        return self.__se_infer(input=spec).reshape(1,256,1)

    def mix_tone_color(self, colors:list):
        '''
        Mix multiple tone colors into a single one by averaging them

        Args:
            colors (list[numpy.ndarray]): The tone colors to mix. Each element should
                be the result of extract_tone_color

        Returns:
            numpy.ndarray: The element-wise mean of the given tone colors
        '''
        return np.stack(colors).mean(axis=0)

    def tone_clone(self, audio:np.ndarray, target_tone_color:np.ndarray, tau=0.3):
        '''
        Clone the tone

        The source spectrogram is zero-padded or truncated to target_length frames
        before being fed to the RKNN model. Frames beyond target_length are dropped.

        Args:
            audio (numpy.ndarray): The audio whose tone will be changed, at the
                model's sample rate
            target_tone_color (numpy.ndarray): The tone color to clone, shape (1, 256, 1).
                It should be the result of extract_tone_color or mix_tone_color
            tau (float): Passed to the tone-clone model as a float32 array of shape (1,).
                Its meaning is defined by the model (presumably a sampling
                temperature / noise scale - confirm against the model export)

        Returns:
            numpy.ndarray: The generated audio (element [0, 0] of the first model
            output), trimmed to the unpadded length when the input was padded
        '''
        assert (target_tone_color.shape == (1,256,1)), "The target tone color must be an array with shape (1,256,1)"
        hps = self.__hparams
        src = self.to_mono(audio.astype(np.float32))
        src = self.__spectrogram_numpy(src, hps.data.filter_length,
            hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
        )
        src_tone = self.__se_infer(input=src).reshape(1,256,1)
        # (1, frames, bins) -> (1, bins, frames)
        src = np.transpose(src, (0, 2, 1))

        # Remember the original number of frames.
        original_length = src.shape[2]
        # Pad or truncate to the fixed length the RKNN model was built for.
        if original_length > self.__target_length:
            if self.__verbose:
                print(f"Input length {original_length} exceeds target length {self.__target_length}, truncating...")
            src = src[:, :, :self.__target_length]
        elif original_length < self.__target_length:
            if self.__verbose:
                print(f"Input length {original_length} is less than target length {self.__target_length}, padding...")
            pad_width = ((0, 0), (0, 0), (0, self.__target_length - original_length))
            src = np.pad(src, pad_width, mode='constant', constant_values=0)
        src_length = np.array([self.__target_length], dtype=np.int64)  # fixed length

        if self.__verbose:
            print("src shape", src.shape)
            print("src_length shape", src_length.shape)
            print("src_tone shape", src_tone.shape)
            print("target_tone_color shape", target_tone_color.shape)
            print("tau", tau)

        # Inputs for RKNNLite, in the order the model expects them.
        inputs = [
            src,
            src_length,
            src_tone,
            target_tone_color,
            np.array([tau], dtype=np.float32)
        ]
        # Run inference with RKNNLite.
        outputs = self.__tc_rknn.inference(inputs=inputs)
        res = outputs[0][0, 0]  # first sample of the first output

        # If the input was padded, drop the audio generated for the padding.
        # Samples per frame are derived from the actual output length instead of
        # a hard-coded constant, so other target_length values work too.
        if original_length < self.__target_length:
            samples_per_frame = res.shape[-1] / self.__target_length
            res = res[:int(original_length * samples_per_frame)]
        if self.__verbose:
            print("res shape", res.shape)
        return res

    def to_mono(self, audio:np.ndarray):
        '''
        Change the audio to be a mono audio

        Args:
            audio (numpy.ndarray): The source audio, 1-D or (samples, channels)

        Returns:
            numpy.ndarray: The mono audio data (channels averaged along axis 1)
        '''
        return np.mean(audio, axis=1) if len(audio.shape) > 1 else audio

    def resample(self, audio:np.ndarray, original_rate:int):
        '''
        Resample the audio to the model's sample rate using linear interpolation

        Args:
            audio (numpy.ndarray): The source audio you want to resample
            original_rate (int): The original sample rate of the source audio

        Returns:
            numpy.ndarray: The mono audio data after resampling
        '''
        audio = self.to_mono(audio)
        target_rate = self.__hparams.data.sampling_rate
        source_count = audio.shape[0]
        # Integer product first so equal rates keep the exact sample count.
        target_count = int(source_count * target_rate / original_rate)
        # Sample i of a signal sampled at rate r sits at time i / r.
        time_original = np.arange(source_count) / original_rate
        time_target = np.arange(target_count) / target_rate
        resampled_data = np.interp(time_target, time_original, audio)
        return resampled_data

    @property
    def sample_rate(self):
        '''
        The sample rate of the tone cloning result
        '''
        return self.__hparams.data.sampling_rate
# Demo: clone the tone of "target.wav" onto "src2.wav" and write "result.wav".
# The model files (configuration.json, tone_color_extract_model.onnx,
# tone_clone_model.rknn) are loaded from the current directory.
tc = OpenVoiceToneClone_ONNXRKNN(".",verbose=True)
import soundfile

# soundfile.read returns (data, samplerate); resample to the model's rate.
tgt_audio, tgt_rate = soundfile.read("target.wav", dtype='float32')
tgt = tc.resample(tgt_audio, tgt_rate)

# Time extract_tone_color. perf_counter is monotonic, unlike time.time.
start_time = time.perf_counter()
tgt_tone_color = tc.extract_tone_color(tgt)
extract_time = time.perf_counter() - start_time
print(f"提取音色特征耗时: {extract_time:.2f}秒")

src_audio, src_rate = soundfile.read("src2.wav", dtype='float32')
src = tc.resample(src_audio, src_rate)

# Time tone_clone.
start_time = time.perf_counter()
result = tc.tone_clone(src, tgt_tone_color)
clone_time = time.perf_counter() - start_time
print(f"克隆音色耗时: {clone_time:.2f}秒")

soundfile.write("result.wav", result, tc.sample_rate)