| import torch |
| import torch.nn as nn |
| from openvoice.api import ToneColorConverter |
| from openvoice.models import SynthesizerTrn |
| import os |
|
|
# Run relative to this script's own directory so the relative checkpoint and
# output paths below resolve regardless of the caller's working directory.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
|
class ToneColorExtractWrapper(nn.Module):
    """Export shim mapping a spectrogram to a tone-color embedding.

    Wraps the converter model's reference encoder (``ref_enc``) so the
    extraction step can be traced and exported on its own.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio):
        # Force a contiguous memory layout before handing the tensor to the
        # reference encoder (helps the ONNX tracer).
        return self.model.ref_enc(audio.contiguous())
|
|
class ToneCloneWrapper(nn.Module):
    """Export shim around the converter model's voice-conversion pass.

    Takes source audio features plus source/target tone embeddings and a
    1-element ``tau`` tensor, and returns only the converted audio (the
    first element of ``voice_conversion``'s result tuple).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio, audio_lengths, src_tone, dest_tone, tau):
        # Contiguous layouts for every traced tensor input.
        audio, src_tone, dest_tone = (
            t.contiguous() for t in (audio, src_tone, dest_tone)
        )
        converted, _, _ = self.model.voice_conversion(
            audio,
            audio_lengths,
            sid_src=src_tone,
            sid_tgt=dest_tone,
            tau=tau[0],  # unwrap the scalar from the 1-element tensor input
        )
        return converted
|
|
def export_models(ckpt_path, output_dir, target_audio_lens, source_audio_lens):
    """Export the tone-color extraction and tone-clone models to ONNX.

    Args:
        ckpt_path: Directory holding the converter's ``config.json`` and
            ``checkpoint.pth``.
        output_dir: Directory the ``.onnx`` files are written to (created
            if it does not exist).
        target_audio_lens: Frame counts used to build dummy inputs for the
            clone-model export.
        source_audio_lens: Frame counts used to build dummy inputs for the
            extraction-model export.
    """
    # Tracing/export runs fine on CPU; no GPU required.
    device = "cpu"
    converter = ToneColorConverter(f'{ckpt_path}/config.json', device=device)
    converter.load_ckpt(f'{ckpt_path}/checkpoint.pth')

    os.makedirs(output_dir, exist_ok=True)

    # --- Tone-color extraction model --------------------------------------
    extract_wrapper = ToneColorExtractWrapper(converter.model)
    extract_wrapper.eval()

    for source_len in source_audio_lens:
        # Dummy input: (batch, frames, 513) — presumably n_fft/2 + 1 spectral
        # bins; TODO confirm against the converter config.
        dummy_input = torch.randn(1, source_len, 513).contiguous()
        # BUG FIX: the original wrote every iteration to the same path, so
        # with multiple lengths each export overwrote the previous one.
        # Keep the legacy file name for a single length (backward
        # compatible); suffix the length otherwise.
        suffix = f"_{source_len}" if len(source_audio_lens) > 1 else ""
        output_path = f"{output_dir}/tone_color_extract_model{suffix}.onnx"

        with torch.no_grad():  # no autograd bookkeeping needed for tracing
            torch.onnx.export(
                extract_wrapper,
                dummy_input,
                output_path,
                input_names=['input'],
                output_names=['tone_embedding'],
                dynamic_axes={
                    'input': {1: 'source_audio_len'},
                },
                # NOTE(review): clone model below exports with opset 17 —
                # consider unifying; left at 11 to preserve behavior.
                opset_version=11,
                do_constant_folding=True,
                verbose=True
            )
        print(f"Exported tone extract model to {output_path}")

    # --- Tone clone (voice conversion) model -------------------------------
    clone_wrapper = ToneCloneWrapper(converter.model)
    clone_wrapper.eval()

    for target_len in target_audio_lens:
        dummy_inputs = (
            torch.randn(1, 513, target_len).contiguous(),  # audio features
            torch.LongTensor([target_len]),                # valid frame count
            torch.randn(1, 256, 1).contiguous(),           # source tone embedding
            torch.randn(1, 256, 1).contiguous(),           # target tone embedding
            torch.FloatTensor([0.3])                       # tau (1-element tensor)
        )

        # Same overwrite fix as the extraction loop above.
        suffix = f"_{target_len}" if len(target_audio_lens) > 1 else ""
        output_path = f"{output_dir}/tone_clone_model{suffix}.onnx"

        with torch.no_grad():
            torch.onnx.export(
                clone_wrapper,
                dummy_inputs,
                output_path,
                input_names=['audio', 'audio_length', 'src_tone', 'dest_tone', 'tau'],
                output_names=['converted_audio'],
                dynamic_axes={
                    'audio': {2: 'target_audio_len'},
                },
                opset_version=17,
                do_constant_folding=True,
                verbose=True
            )
        print(f"Exported tone clone model to {output_path}")
|
|
if __name__ == "__main__":
    # Example lengths for the dummy export inputs; the exported graphs mark
    # the audio-length axes as dynamic, so a single length per model suffices.
    TARGET_AUDIO_LENS = [1024]
    SOURCE_AUDIO_LENS = [1024]

    export_models(
        "checkpoints_v2/converter",
        "onnx_models",
        target_audio_lens=TARGET_AUDIO_LENS,
        source_audio_lens=SOURCE_AUDIO_LENS,
    )