happyme531 committed on
Commit
cc403c3
·
verified ·
1 Parent(s): beb0993

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ result.wav filter=lfs diff=lfs merge=lfs -text
37
+ src2.wav filter=lfs diff=lfs merge=lfs -text
38
+ target.wav filter=lfs diff=lfs merge=lfs -text
39
+ tone_clone_model.rknn filter=lfs diff=lfs merge=lfs -text
configuration.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_version_": "v2",
3
+ "data": {
4
+ "sampling_rate": 22050,
5
+ "filter_length": 1024,
6
+ "hop_length": 256,
7
+ "win_length": 1024,
8
+ "n_speakers": 0
9
+ },
10
+ "model": {
11
+ "zero_g": true,
12
+ "inter_channels": 192,
13
+ "hidden_channels": 192,
14
+ "filter_channels": 768,
15
+ "n_heads": 2,
16
+ "n_layers": 6,
17
+ "kernel_size": 3,
18
+ "p_dropout": 0.1,
19
+ "resblock": "1",
20
+ "resblock_kernel_sizes": [
21
+ 3,
22
+ 7,
23
+ 11
24
+ ],
25
+ "resblock_dilation_sizes": [
26
+ [
27
+ 1,
28
+ 3,
29
+ 5
30
+ ],
31
+ [
32
+ 1,
33
+ 3,
34
+ 5
35
+ ],
36
+ [
37
+ 1,
38
+ 3,
39
+ 5
40
+ ]
41
+ ],
42
+ "upsample_rates": [
43
+ 8,
44
+ 8,
45
+ 2,
46
+ 2
47
+ ],
48
+ "upsample_initial_channel": 512,
49
+ "upsample_kernel_sizes": [
50
+ 16,
51
+ 16,
52
+ 4,
53
+ 4
54
+ ],
55
+ "gin_channels": 256
56
+ }
57
+ }
convert_rknn.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import datetime
5
+ import argparse
6
+ from rknn.api import RKNN
7
+ from sys import exit
8
+
9
# Model configuration: maps a model-type key to its ONNX file name.
MODELS = {
    'tone_clone': 'tone_clone_model.onnx',
    'tone_color_extract': 'tone_color_extract_model.onnx',
}

# Spectrogram lengths (frames) baked into the dynamic-shape RKNN build.
TARGET_AUDIO_LENS = [1024]

SOURCE_AUDIO_LENS = [1024]

# Spectrogram feature dimension (matches the 513-wide model inputs below).
AUDIO_DIM = 513

# Quantization flag for the RKNN build step.
QUANTIZE=False
detailed_performance_log = True
23
+
24
def convert_model(model_type):
    """Convert one ONNX model to RKNN format for the RK3588 NPU.

    Args:
        model_type (str): A key of ``MODELS`` ('tone_clone' or
            'tone_color_extract').

    Returns:
        bool: True if the conversion succeeded, False otherwise.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")

    # Dynamic input shapes, one shape set per supported length.
    # Initialized to None defensively so `shapes` is always bound.
    shapes = None
    if model_type == 'tone_clone':
        shapes = [
            [
                [1, 513, target_audio_len],  # audio
                [1],                         # audio_length
                [1, 256, 1],                 # src_tone
                [1, 256, 1],                 # dest_tone
                [1],                         # tau
            ] for target_audio_len in TARGET_AUDIO_LENS
        ]
    elif model_type == 'tone_color_extract':
        shapes = [
            [
                [1, source_audio_len, 513],  # audio
            ] for source_audio_len in SOURCE_AUDIO_LENS
        ]

    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR = False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted by: qq: 232004040, email: 2302004040@qq.com at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=shapes,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
            # disable_rules=['convert_gemm_by_exmatmul']
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        # Honor the module-level QUANTIZE flag; it was previously declared
        # but ignored (do_quantization was hard-coded to False).
        ret = rknn.build(do_quantization=QUANTIZE, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        # Always free the toolkit context, even on early returns.
        rknn.release()
96
+
97
def main():
    """CLI entry point: pick the model type(s) and run the conversion."""
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        choices=['all', 'tone_clone', 'tone_color_extract'],
                        help='要转换的模型类型 (默认: all)')

    args = parser.parse_args()

    # 'all' expands to every configured model; otherwise convert just the one.
    selected = list(MODELS) if args.model_type == 'all' else [args.model_type]
    for name in selected:
        if not convert_model(name):
            print(f"转换 {name} 失败")

if __name__ == '__main__':
    main()
117
+
118
+
export_onnx.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from openvoice.api import ToneColorConverter
4
+ from openvoice.models import SynthesizerTrn
5
+ import os
6
+
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+
9
class ToneColorExtractWrapper(nn.Module):
    """Export wrapper exposing the tone-color (speaker embedding) extractor
    as a single-input module for torch.onnx.export."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio):
        # audio: [1, source_audio_len, 513] spectrogram.
        spec = audio.contiguous()
        # Run the reference encoder to obtain the tone-color embedding.
        # (Callers reshape the result to [1, 256, 1] themselves.)
        embedding = self.model.ref_enc(spec)
        return embedding
23
+
24
class ToneCloneWrapper(nn.Module):
    """Export wrapper turning voice_conversion into a flat, positional-input
    module suitable for torch.onnx.export."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio, audio_lengths, src_tone, dest_tone, tau):
        # Make the tensors contiguous before tracing the conversion call.
        converted, _, _ = self.model.voice_conversion(
            audio.contiguous(),
            audio_lengths,
            sid_src=src_tone.contiguous(),
            sid_tgt=dest_tone.contiguous(),
            tau=tau[0],
        )
        return converted
44
+
45
def export_models(ckpt_path, output_dir, target_audio_lens, source_audio_lens):
    """
    Export the tone-color extractor and the tone-clone model to ONNX.

    Args:
        ckpt_path: Path to the checkpoint directory (config.json + checkpoint.pth)
        output_dir: Output directory for the .onnx files
        target_audio_lens: List of target audio (spectrogram frame) lengths
        source_audio_lens: List of source audio (spectrogram frame) lengths
    """

    # Load the OpenVoice tone-color converter on CPU.
    device = "cpu"
    converter = ToneColorConverter(f'{ckpt_path}/config.json', device=device)
    converter.load_ckpt(f'{ckpt_path}/checkpoint.pth')

    # Create the output directory if needed.
    os.makedirs(output_dir, exist_ok=True)

    # Export the tone-color extraction model.
    extract_wrapper = ToneColorExtractWrapper(converter.model)
    extract_wrapper.eval()

    for source_len in source_audio_lens:
        dummy_input = torch.randn(1, source_len, 513).contiguous()
        # NOTE(review): the path does not include source_len, so with more
        # than one entry in source_audio_lens each iteration overwrites the
        # previous export — confirm this is intended.
        output_path = f"{output_dir}/tone_color_extract_model.onnx"

        torch.onnx.export(
            extract_wrapper,
            dummy_input,
            output_path,
            input_names=['input'],
            output_names=['tone_embedding'],
            dynamic_axes={
                'input': {1: 'source_audio_len'},
            },
            opset_version=11,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone extract model to {output_path}")

    # Export the tone-clone (voice conversion) model.
    clone_wrapper = ToneCloneWrapper(converter.model)
    clone_wrapper.eval()

    for target_len in target_audio_lens:
        dummy_inputs = (
            torch.randn(1, 513, target_len).contiguous(),  # audio
            torch.LongTensor([target_len]),                # audio_lengths
            torch.randn(1, 256, 1).contiguous(),           # src_tone
            torch.randn(1, 256, 1).contiguous(),           # dest_tone
            torch.FloatTensor([0.3])                       # tau
        )

        # NOTE(review): same overwrite caveat as above for multiple lengths.
        output_path = f"{output_dir}/tone_clone_model.onnx"

        # NOTE(review): opset 17 here vs opset 11 for the extractor above —
        # confirm the mismatch is deliberate.
        torch.onnx.export(
            clone_wrapper,
            dummy_inputs,
            output_path,
            input_names=['audio', 'audio_length', 'src_tone', 'dest_tone', 'tau'],
            output_names=['converted_audio'],
            dynamic_axes={
                'audio': {2: 'target_audio_len'},
            },
            opset_version=17,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone clone model to {output_path}")
116
+
117
if __name__ == "__main__":
    # Example usage.
    TARGET_AUDIO_LENS = [1024]  # set the target lengths as needed
    SOURCE_AUDIO_LENS = [1024]  # set the source lengths as needed

    export_models(
        ckpt_path="checkpoints_v2/converter",
        output_dir="onnx_models",
        target_audio_lens=TARGET_AUDIO_LENS,
        source_audio_lens=SOURCE_AUDIO_LENS
    )
result.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d11ad289cc5014994086548874fd145ac67c41eb9b91fdd822ad6bd05a40c90f
3
+ size 393260
src2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baf4ce666c5fa88e052381e0c33543be3015bf2f47154ac3925ee67c963c0a12
3
+ size 1712078
target.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c63d1b5cb444f3611a271d1c24d04363f5bdd73fb5745bc6b61e1c925a8f6084
3
+ size 2165838
test_rknn.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ import os
5
+ from rknnlite.api import RKNNLite
6
+ import json
7
+ import os
8
+ import time
9
+
10
class HParams:
    """Attribute-style hyperparameter container.

    Nested dicts are converted recursively, so values can be reached both as
    attributes (hps.data.sampling_rate) and as items (hps['data']).
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            # Exact dicts become nested HParams; everything else is stored as-is.
            self[key] = HParams(**value) if type(value) == dict else value

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return repr(self.__dict__)

    @staticmethod
    def load_from_file(file_path: str):
        """Load a JSON configuration file into an HParams tree."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Can not found the configuration file \"{file_path}\"")
        with open(file_path, "r", encoding="utf-8") as f:
            parsed = json.load(f)
        return HParams(**parsed)
48
+
49
class BaseClassForOnnxInfer():
    """Shared helpers for building onnxruntime inference callables."""

    @staticmethod
    def create_onnx_infer(infer_factor:Callable, onnx_model_path:str, providers:list, session_options:ort.SessionOptions = None, onnx_params:dict = None):
        # Create a session for the model file and let the factory wrap it
        # into a plain callable.
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"Can not found the onnx model file \"{onnx_model_path}\"")
        session = ort.InferenceSession(onnx_model_path, sess_options=BaseClassForOnnxInfer.adjust_onnx_session_options(session_options), providers=providers, **(onnx_params or {}))
        fn = infer_factor(session)
        # Pin the session on the callable so it stays alive as long as fn does.
        fn.__session = session
        return fn

    @staticmethod
    def get_def_onnx_session_options():
        # Default session options: enable all graph optimizations.
        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        return session_options

    @staticmethod
    def adjust_onnx_session_options(session_options:ort.SessionOptions = None):
        # Fall back to the defaults when the caller passed None.
        return session_options or BaseClassForOnnxInfer.get_def_onnx_session_options()
68
+
69
class OpenVoiceToneClone_ONNXRKNN(BaseClassForOnnxInfer):
    """Tone-cloning pipeline: ONNX tone-color extractor (onnxruntime on CPU)
    plus an RKNN voice-conversion model running on the NPU via RKNNLite."""

    PreferredProviders = ['CPUExecutionProvider']

    def __init__(self, model_path, execution_provider:str = None, verbose:bool = False, onnx_session_options:ort.SessionOptions = None, onnx_params:dict = None, target_length:int = 1024):
        '''
        Create an instance of the tone cloner.

        Args:
            model_path (str): Path of the folder which contains the model files
            execution_provider (str): Provider for onnxruntime, such as CPUExecutionProvider, CUDAExecutionProvider, etc. (short forms like CPU, CUDA also accepted). If None, one is chosen automatically
            verbose (bool): Set True to print more detailed information while working
            onnx_session_options (onnxruntime.SessionOptions): Custom options for the onnx session
            onnx_params (dict): Extra parameters passed to the onnxruntime.InferenceSession constructor
            target_length (int): Fixed length for padding/truncating the spectrogram, defaults to 1024

        Returns:
            OpenVoiceToneClone_ONNX: The instance of the tone cloner
        '''
        self.__verbose = verbose
        self.__target_length = target_length

        if verbose:
            print("Loading the configuration...")
        config_path = os.path.join(model_path, "configuration.json")
        self.__hparams = HParams.load_from_file(config_path)

        # Normalize short provider names ("CPU" -> "CPUExecutionProvider").
        execution_provider = f"{execution_provider}ExecutionProvider" if (execution_provider is not None) and (not execution_provider.endswith("ExecutionProvider")) else execution_provider
        available_providers = ort.get_available_providers()
        # self.__execution_providers = [execution_provider if execution_provider in available_providers else next((provider for provider in MeloTTS_ONNX.PreferredProviders if provider in available_providers), 'CPUExecutionProvider')]
        # NOTE(review): provider selection is hard-coded to CPU below; the
        # execution_provider argument is effectively ignored.
        self.__execution_providers = ['CPUExecutionProvider']
        if verbose:
            print("Creating onnx session for tone color extractor...")
        def se_infer_factor(session):
            # Returns the first output; inputs are passed as keyword args.
            return lambda **kwargs: session.run(None, kwargs)[0]
        self.__se_infer = self.create_onnx_infer(se_infer_factor, os.path.join(model_path, "tone_color_extract_model.onnx"), self.__execution_providers, onnx_session_options, onnx_params)

        if verbose:
            print("Creating RKNNLite session for tone clone ...")
        # Initialize RKNNLite.
        self.__tc_rknn = RKNNLite(verbose=verbose)
        # Load the RKNN model.
        ret = self.__tc_rknn.load_rknn(os.path.join(model_path, "tone_clone_model.rknn"))
        if ret != 0:
            raise RuntimeError("Failed to load RKNN model")
        # Initialize the runtime.
        ret = self.__tc_rknn.init_runtime()
        if ret != 0:
            raise RuntimeError("Failed to init RKNN runtime")

    def __del__(self):
        """Release the RKNN runtime resources."""
        # hasattr guard: __init__ may have raised before __tc_rknn was set.
        if hasattr(self, '_OpenVoiceToneClone_ONNXRKNN__tc_rknn'):
            self.__tc_rknn.release()

    # Cache of Hann windows keyed by "<dtype>-<win_size>".
    hann_window = {}

    def __spectrogram_numpy(self, y, n_fft, sampling_rate, hop_size, win_size, onesided=True):
        # Compute a magnitude spectrogram of the 1-D signal `y` using NumPy only.
        if self.__verbose:
            # Warn when the signal exceeds the usual [-1, 1] float-audio range.
            if np.min(y) < -1.1:
                print("min value is ", np.min(y))
            if np.max(y) > 1.1:
                print("max value is ", np.max(y))

        # Reflect-pad so frame centers align with the original samples.
        y = np.pad(
            y,
            int((n_fft - hop_size) / 2),
            mode="reflect",
        )

        # Build (and cache) the Hann window.
        win_key = f"{str(y.dtype)}-{win_size}"
        # NOTE(review): `True or ...` makes the cache check dead code, so the
        # window is rebuilt on every call; also the bare `hann_window` name on
        # the right-hand side would raise NameError if it were ever evaluated
        # (the class attribute is OpenVoiceToneClone_ONNXRKNN.hann_window).
        if True or win_key not in hann_window:
            OpenVoiceToneClone_ONNXRKNN.hann_window[win_key] = np.hanning(win_size + 1)[:-1].astype(y.dtype)
        window = OpenVoiceToneClone_ONNXRKNN.hann_window[win_key]

        # Short-time Fourier transform: manual sliding-window (r)FFT.
        y_len = y.shape[0]
        win_len = window.shape[0]
        count = int((y_len - win_len) // hop_size) + 1
        spec = np.empty((count, int(win_len / 2) + 1 if onesided else (int(win_len / 2) + 1) * 2, 2))
        start = 0
        end = start + win_len
        idx = 0
        while end <= y_len:
            segment = y[start:end]
            frame = segment * window
            step_result = np.fft.rfft(frame) if onesided else np.fft.fft(frame)
            # Store real and imaginary parts side by side in the last axis.
            spec[idx] = np.column_stack((step_result.real, step_result.imag))
            start = start + hop_size
            end = start + win_len
            idx += 1

        # Magnitude from the real/imaginary parts (epsilon avoids sqrt(0)).
        spec = np.sqrt(np.sum(np.square(spec), axis=-1) + 1e-6)

        return np.array([spec], dtype=np.float32)

    def extract_tone_color(self, audio:np.array):
        '''
        Extract the tone color from an audio signal.

        Args:
            audio (numpy.array): The data of the audio

        Returns:
            numpy.array: The tone color vector, reshaped to (1, 256, 1)
        '''
        hps = self.__hparams
        y = self.to_mono(audio.astype(np.float32))
        spec = self.__spectrogram_numpy(y, hps.data.filter_length,
            hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
        )

        if self.__verbose:
            print("spec shape", spec.shape)
        return self.__se_infer(input=spec).reshape(1,256,1)

    def mix_tone_color(self, colors:list):
        '''
        Mix multiple tone colors into a single one.

        Args:
            colors (list[numpy.array]): The tone colors to mix. Each element should be a result of extract_tone_color.

        Returns:
            numpy.array: The averaged tone color vector
        '''
        return np.stack(colors).mean(axis=0)

    def tone_clone(self, audio:np.array, target_tone_color:np.array, tau=0.3):
        '''
        Clone the tone: convert `audio` to sound like the target tone color.

        Args:
            audio (numpy.array): The data of the audio whose tone will be changed
            target_tone_color (numpy.array): The tone color to clone. It should be a result of extract_tone_color or mix_tone_color.
            tau (float): Conversion temperature passed to the model (default 0.3)

        Returns:
            numpy.array: The converted audio
        '''
        assert (target_tone_color.shape == (1,256,1)), "The target tone color must be an array with shape (1,256,1)"
        hps = self.__hparams
        src = self.to_mono(audio.astype(np.float32))
        src = self.__spectrogram_numpy(src, hps.data.filter_length,
            hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
        )
        # Extract the source speaker's tone color from the same spectrogram.
        src_tone = self.__se_infer(input=src).reshape(1,256,1)

        # [1, frames, 513] -> [1, 513, frames], the layout the RKNN model expects.
        src = np.transpose(src, (0, 2, 1))
        # Remember the original frame count.
        original_length = src.shape[2]

        # Pad or truncate to the fixed length (the RKNN input shape is static).
        if original_length > self.__target_length:
            if self.__verbose:
                print(f"Input length {original_length} exceeds target length {self.__target_length}, truncating...")
            src = src[:, :, :self.__target_length]
        elif original_length < self.__target_length:
            if self.__verbose:
                print(f"Input length {original_length} is less than target length {self.__target_length}, padding...")
            pad_width = ((0, 0), (0, 0), (0, self.__target_length - original_length))
            src = np.pad(src, pad_width, mode='constant', constant_values=0)

        src_length = np.array([self.__target_length], dtype=np.int64)  # use the fixed length

        if self.__verbose:
            print("src shape", src.shape)
            print("src_length shape", src_length.shape)
            print("src_tone shape", src_tone.shape)
            print("target_tone_color shape", target_tone_color.shape)
            print("tau", tau)

        # Prepare the inputs for RKNNLite (same order as the exported model).
        inputs = [
            src,
            src_length,
            src_tone,
            target_tone_color,
            np.array([tau], dtype=np.float32)
        ]

        # Run inference with RKNNLite.
        outputs = self.__tc_rknn.inference(inputs=inputs)
        res = outputs[0][0, 0]  # first output, first batch/channel element

        # Output samples per spectrogram frame: 262144/1024 = 256, which
        # presumably matches the vocoder's total upsampling factor
        # (8*8*2*2 = 256) — TODO confirm.
        generated_multiplier = 262144 / 1024
        # If the original input was shorter, cut off the part generated from padding.
        if original_length < self.__target_length:
            res = res[:int(original_length * generated_multiplier)]

        if self.__verbose:
            print("res shape", res.shape)
        return res

    def to_mono(self, audio:np.array):
        '''
        Convert the audio to a mono audio.

        Args:
            audio (numpy.array): The source audio

        Returns:
            numpy.array: The mono audio data
        '''
        return np.mean(audio, axis=1) if len(audio.shape) > 1 else audio

    def resample(self, audio:np.array, original_rate:int):
        '''
        Resample the audio to match the model's sample rate.

        Args:
            audio (numpy.array): The source audio you want to resample.
            original_rate (int): The original sample rate of the source audio

        Returns:
            numpy.array: The audio data after resampling (linear interpolation)
        '''
        audio = self.to_mono(audio)
        target_rate = self.__hparams.data.sampling_rate
        duration = audio.shape[0] / original_rate
        target_length = int(duration * target_rate)
        time_original = np.linspace(0, duration, num=audio.shape[0])
        time_target = np.linspace(0, duration, num=target_length)
        resampled_data = np.interp(time_target, time_original, audio)
        return resampled_data

    @property
    def sample_rate(self):
        '''
        The sample rate of the tone-cloning result.
        '''
        return self.__hparams.data.sampling_rate
304
+
305
+
306
# Smoke test: clone the voice of target.wav onto the speech in src2.wav.
tc = OpenVoiceToneClone_ONNXRKNN(".",verbose=True)
import soundfile

# Load the reference (target-voice) audio and resample it to the model rate.
tgt = soundfile.read("target.wav", dtype='float32')
tgt = tc.resample(tgt[0], tgt[1])

# Time extract_tone_color.
start_time = time.time()
tgt_tone_color = tc.extract_tone_color(tgt)
extract_time = time.time() - start_time
print(f"提取音色特征耗时: {extract_time:.2f}秒")

# Load the source audio whose voice will be converted.
src = soundfile.read("src2.wav", dtype='float32')
src = tc.resample(src[0], src[1])

# Time tone_clone.
start_time = time.time()
result = tc.tone_clone(src, tgt_tone_color)
clone_time = time.time() - start_time
print(f"克隆音色耗时: {clone_time:.2f}秒")

soundfile.write("result.wav", result, tc.sample_rate)
tone_clone_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:896195b84b0cb87a828bb8cab06577e9c024356bc9727b1a8f4174154bc0affa
3
+ size 157196170
tone_clone_model.rknn ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cd7dc3385c55ca610580edaba263510091314be35ae4688a1c076afe9e5d84a
3
+ size 108102277
tone_color_extract_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e91c2cb696e199d2519ed8b62ca6e3c8e42cb99ca13955dd6e188051486e681c
3
+ size 3364792