File size: 12,764 Bytes
cc403c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
from typing import Callable
import numpy as np
import onnxruntime as ort
import os
from rknnlite.api import RKNNLite
import json
import os
import time

class HParams:
    '''
    A lightweight attribute container built from keyword arguments.

    Nested dicts are wrapped recursively, so configuration values can be
    accessed both as attributes (hps.data.sampling_rate) and as items
    (hps["data"]["sampling_rate"]).
    '''

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (not type ==) so dict subclasses are wrapped too.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        '''Return the parameter names.'''
        return self.__dict__.keys()

    def items(self):
        '''Return (name, value) pairs.'''
        return self.__dict__.items()

    def values(self):
        '''Return the parameter values.'''
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()

    @staticmethod
    def load_from_file(file_path: str):
        '''
        Load an HParams tree from a JSON configuration file.

        Args:
            file_path (str): Path of the JSON configuration file

        Returns:
            HParams: The parsed configuration

        Raises:
            FileNotFoundError: When the file does not exist
        '''
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Can not found the configuration file \"{file_path}\"")
        with open(file_path, "r", encoding="utf-8") as f:
            hps = json.load(f)
        return HParams(**hps)
        
class BaseClassForOnnxInfer():
    '''Base class with helpers for building onnxruntime-backed inference callables.'''

    @staticmethod
    def create_onnx_infer(infer_factor:Callable, onnx_model_path:str, providers:list, session_options:ort.SessionOptions = None, onnx_params:dict = None):
        '''
        Load an ONNX model and wrap it into a callable produced by `infer_factor`.

        Args:
            infer_factor (Callable): Receives the InferenceSession, returns the inference callable
            onnx_model_path (str): Path of the .onnx model file
            providers (list): onnxruntime execution providers to use
            session_options (onnxruntime.SessionOptions): Optional session options; defaults applied when None
            onnx_params (dict): Extra keyword arguments for onnxruntime.InferenceSession

        Returns:
            Callable: The inference function produced by `infer_factor`

        Raises:
            FileNotFoundError: When the model file does not exist
        '''
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"Can not found the onnx model file \"{onnx_model_path}\"")
        opts = BaseClassForOnnxInfer.adjust_onnx_session_options(session_options)
        extra_params = onnx_params or {}
        session = ort.InferenceSession(onnx_model_path, sess_options=opts, providers=providers, **extra_params)
        infer_fn = infer_factor(session)
        # Pin the session onto the callable so it stays alive as long as the callable does.
        infer_fn.__session = session
        return infer_fn

    @staticmethod
    def get_def_onnx_session_options():
        '''Return default SessionOptions with full graph optimization enabled.'''
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        return opts

    @staticmethod
    def adjust_onnx_session_options(session_options:ort.SessionOptions = None):
        '''Return the given options unchanged, or the defaults when none were given.'''
        return session_options or BaseClassForOnnxInfer.get_def_onnx_session_options()
    
class OpenVoiceToneClone_ONNXRKNN(BaseClassForOnnxInfer):
    '''
    OpenVoice tone cloner that runs the tone-color extractor via onnxruntime
    and the tone-clone (voice conversion) model via RKNNLite.
    '''

    # Providers considered when auto-selecting; currently the ONNX part always runs on CPU.
    PreferredProviders = ['CPUExecutionProvider']

    # Class-level cache of Hann windows keyed by "<dtype>-<win_size>".
    hann_window = {}

    def __init__(self, model_path, execution_provider:str = None, verbose:bool = False, onnx_session_options:ort.SessionOptions = None, onnx_params:dict = None, target_length:int = 1024):
        '''
        Create the instance of the tone cloner

        Args:
            model_path (str): The path of the folder which contains the model
            execution_provider (str): The provider that onnxruntime uses (e.g. "CUDA" or
                "CUDAExecutionProvider"). NOTE(review): currently ignored — the ONNX session
                is always created with CPUExecutionProvider; the parameter is kept for
                interface compatibility.
            verbose (bool): Set True to show more detail information when working
            onnx_session_options (onnxruntime.SessionOptions): The custom options for the onnx session
            onnx_params (dict): Other parameters to pass to the onnxruntime.InferenceSession constructor
            target_length (int): The fixed spectrogram frame count the RKNN model expects;
                inputs are padded/truncated to this length. Defaults to 1024.

        Raises:
            FileNotFoundError: When the configuration or onnx model file is missing
            RuntimeError: When the RKNN model fails to load or its runtime fails to init
        '''
        self.__verbose = verbose
        self.__target_length = target_length

        if verbose:
            print("Loading the configuration...")
        config_path = os.path.join(model_path, "configuration.json")
        self.__hparams = HParams.load_from_file(config_path)

        # The RKNN build of this pipeline pins the ONNX extractor to CPU;
        # the execution_provider argument is accepted but not honored here.
        self.__execution_providers = ['CPUExecutionProvider']

        if verbose:
            print("Creating onnx session for tone color extractor...")
        def se_infer_factor(session):
            # Single named input, first output is the speaker embedding.
            return lambda **kwargs: session.run(None, kwargs)[0]
        self.__se_infer = self.create_onnx_infer(se_infer_factor, os.path.join(model_path, "tone_color_extract_model.onnx"), self.__execution_providers, onnx_session_options, onnx_params)

        if verbose:
            print("Creating RKNNLite session for tone clone ...")
        # Initialize RKNNLite and load the converted tone-clone model.
        self.__tc_rknn = RKNNLite(verbose=verbose)
        ret = self.__tc_rknn.load_rknn(os.path.join(model_path, "tone_clone_model.rknn"))
        if ret != 0:
            raise RuntimeError("Failed to load RKNN model")
        # Initialize the NPU runtime.
        ret = self.__tc_rknn.init_runtime()
        if ret != 0:
            raise RuntimeError("Failed to init RKNN runtime")

    def __del__(self):
        """Release the RKNN runtime resources on destruction."""
        # Use the mangled name so a partially-constructed instance does not raise.
        if hasattr(self, '_OpenVoiceToneClone_ONNXRKNN__tc_rknn'):
            self.__tc_rknn.release()

    def __spectrogram_numpy(self, y, n_fft, sampling_rate, hop_size, win_size, onesided=True):
        '''
        Compute a magnitude spectrogram of the 1-D signal with NumPy.

        Args:
            y (numpy.ndarray): Mono audio samples, expected roughly in [-1, 1]
            n_fft (int): FFT size; together with hop_size it determines the reflect padding
            sampling_rate (int): Sample rate (unused here; kept for interface parity)
            hop_size (int): Hop length between frames, in samples
            win_size (int): Window length, in samples
            onesided (bool): If True use the one-sided rFFT spectrum

        Returns:
            numpy.ndarray: float32 array of shape (1, frames, bins) where bins is
                win_size // 2 + 1 when onesided, else win_size
        '''
        if self.__verbose:
            # Audio is expected to be normalized; report clearly out-of-range values.
            if np.min(y) < -1.1:
                print("min value is ", np.min(y))
            if np.max(y) > 1.1:
                print("max value is ", np.max(y))

        # Reflect-pad so frames are centered on the original samples.
        y = np.pad(
            y,
            int((n_fft - hop_size) / 2),
            mode="reflect",
        )

        # Fetch (or build and cache) a periodic Hann window for this dtype/size.
        # BUGFIX: the original "if True or win_key not in hann_window" disabled the
        # cache and referenced an undefined bare name; use the class attribute.
        win_key = f"{str(y.dtype)}-{win_size}"
        if win_key not in OpenVoiceToneClone_ONNXRKNN.hann_window:
            # np.hanning is symmetric; dropping the last sample makes it periodic.
            OpenVoiceToneClone_ONNXRKNN.hann_window[win_key] = np.hanning(win_size + 1)[:-1].astype(y.dtype)
        window = OpenVoiceToneClone_ONNXRKNN.hann_window[win_key]

        # Short-time Fourier transform, frame by frame.
        y_len = y.shape[0]
        win_len = window.shape[0]
        count = int((y_len - win_len) // hop_size) + 1
        # BUGFIX: the two-sided FFT yields win_len bins (the original allocated
        # win_len + 2, which would break the onesided=False path).
        num_bins = win_len // 2 + 1 if onesided else win_len
        spec = np.empty((count, num_bins, 2))
        start = 0
        end = start + win_len
        idx = 0
        while end <= y_len:
            frame = y[start:end] * window
            step_result = np.fft.rfft(frame) if onesided else np.fft.fft(frame)
            # Store real/imag side by side; magnitude is taken after the loop.
            spec[idx] = np.column_stack((step_result.real, step_result.imag))
            start = start + hop_size
            end = start + win_len
            idx += 1

        # Magnitude from real/imag parts; the epsilon avoids sqrt of exact zero.
        spec = np.sqrt(np.sum(np.square(spec), axis=-1) + 1e-6)

        return np.array([spec], dtype=np.float32)

    def extract_tone_color(self, audio:np.ndarray):
        '''
        Extract the tone color from an audio

        Args:
            audio (numpy.ndarray): The data of the audio

        Returns:
            numpy.ndarray: The tone color vector, shape (1, 256, 1)
        '''
        hps = self.__hparams
        y = self.to_mono(audio.astype(np.float32))
        spec = self.__spectrogram_numpy(y, hps.data.filter_length,
                                    hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
                                    )

        if self.__verbose:
            print("spec shape", spec.shape)
        # The extractor outputs a 256-dim speaker embedding.
        return self.__se_infer(input=spec).reshape(1, 256, 1)

    def mix_tone_color(self, colors:list):
        '''
        Mix multi tone colors to a single one

        Args:
            colors (list[numpy.ndarray]): The tone colors to mix. Each element should
                be the result of extract_tone_color.

        Returns:
            numpy.ndarray: The averaged tone color vector, shape (1, 256, 1)
        '''
        return np.stack(colors).mean(axis=0)

    def tone_clone(self, audio:np.ndarray, target_tone_color:np.ndarray, tau=0.3):
        '''
        Clone the tone

        Args:
            audio (numpy.ndarray): The audio whose tone will be changed
            target_tone_color (numpy.ndarray): The tone color to clone, from
                extract_tone_color or mix_tone_color; must have shape (1, 256, 1)
            tau (float): Conversion temperature passed to the model

        Returns:
            numpy.ndarray: The converted audio samples
        '''
        assert (target_tone_color.shape == (1,256,1)), "The target tone color must be an array with shape (1,256,1)"
        hps = self.__hparams
        src = self.to_mono(audio.astype(np.float32))
        src = self.__spectrogram_numpy(src, hps.data.filter_length,
                                      hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
                                      )
        # The source speaker's own tone vector, from the same spectrogram.
        src_tone = self.__se_infer(input=src).reshape(1,256,1)

        # (1, frames, bins) -> (1, bins, frames) as the RKNN model expects.
        src = np.transpose(src, (0, 2, 1))
        # Remember the real frame count before forcing the fixed model length.
        original_length = src.shape[2]

        # Pad or truncate to the fixed length the RKNN graph was compiled with.
        if original_length > self.__target_length:
            if self.__verbose:
                print(f"Input length {original_length} exceeds target length {self.__target_length}, truncating...")
            src = src[:, :, :self.__target_length]
        elif original_length < self.__target_length:
            if self.__verbose:
                print(f"Input length {original_length} is less than target length {self.__target_length}, padding...")
            pad_width = ((0, 0), (0, 0), (0, self.__target_length - original_length))
            src = np.pad(src, pad_width, mode='constant', constant_values=0)

        # The model is always told the fixed length, not the real one.
        src_length = np.array([self.__target_length], dtype=np.int64)

        if self.__verbose:
            print("src shape", src.shape)
            print("src_length shape", src_length.shape)
            print("src_tone shape", src_tone.shape)
            print("target_tone_color shape", target_tone_color.shape)
            print("tau", tau)

        # Inputs in the order the RKNN model was exported with.
        inputs = [
            src,
            src_length,
            src_tone,
            target_tone_color,
            np.array([tau], dtype=np.float32)
        ]

        # Run the conversion on the NPU.
        outputs = self.__tc_rknn.inference(inputs=inputs)
        res = outputs[0][0, 0]  # first output, first batch item

        # Output samples per spectrogram frame: 262144 samples for 1024 frames
        # (assumed from the exported model — TODO confirm against the export script).
        generated_multiplier = 262144 / 1024
        # Drop the audio that corresponds to the zero-padded frames.
        if original_length < self.__target_length:
            res = res[:int(original_length * generated_multiplier)]

        if self.__verbose:
            print("res shape", res.shape)
        return res

    def to_mono(self, audio:np.ndarray):
        '''
        Change the audio to be a mono audio

        Args:
            audio (numpy.ndarray): The source audio, 1-D or (samples, channels)

        Returns:
            numpy.ndarray: The mono audio data
        '''
        # Average across channels; 1-D input is already mono.
        return np.mean(audio, axis=1) if len(audio.shape) > 1 else audio

    def resample(self, audio:np.ndarray, original_rate:int):
        '''
        Resample the audio to match the model. It is used for changing the sample rate of the audio.

        Args:
            audio (numpy.ndarray): The source audio you want to resample
            original_rate (int): The original sample rate of the source audio

        Returns:
            numpy.ndarray: The mono audio after linear-interpolation resampling
        '''
        audio = self.to_mono(audio)
        target_rate = self.__hparams.data.sampling_rate
        duration = audio.shape[0] / original_rate
        target_length = int(duration * target_rate)
        # Linear interpolation onto the new time grid.
        time_original = np.linspace(0, duration, num=audio.shape[0])
        time_target = np.linspace(0, duration, num=target_length)
        resampled_data = np.interp(time_target, time_original, audio)
        return resampled_data

    @property
    def sample_rate(self):
        '''
        The sample rate of the tone cloning result
        '''
        return self.__hparams.data.sampling_rate
    

tc = OpenVoiceToneClone_ONNXRKNN(".", verbose=True)
import soundfile

# Load the reference speaker and resample it to the model's rate.
tgt_audio, tgt_rate = soundfile.read("target.wav", dtype='float32')
tgt = tc.resample(tgt_audio, tgt_rate)

# Time the tone-color extraction step.
start_time = time.time()
tgt_tone_color = tc.extract_tone_color(tgt)
extract_time = time.time() - start_time
print(f"提取音色特征耗时: {extract_time:.2f}秒")

# Load the source speech to be converted.
src_audio, src_rate = soundfile.read("src2.wav", dtype='float32')
src = tc.resample(src_audio, src_rate)

# Time the tone-clone inference itself.
start_time = time.time()
result = tc.tone_clone(src, tgt_tone_color)
clone_time = time.time() - start_time
print(f"克隆音色耗时: {clone_time:.2f}秒")

soundfile.write("result.wav", result, tc.sample_rate)