| import argparse |
| import time |
| import numpy as np |
| import onnx |
| from onnxsim import simplify |
| import onnxruntime as ort |
| import onnxoptimizer |
| import torch |
| from model_onnx import SynthesizerTrn |
| import utils |
| from hubert import hubert_model_onnx |
|
|
| def main(HubertExport,NetExport): |
|
|
| path = "NyaruTaffy" |
|
|
| if(HubertExport): |
| device = torch.device("cuda") |
| hubert_soft = utils.get_hubert_model() |
| test_input = torch.rand(1, 1, 16000) |
| input_names = ["source"] |
| output_names = ["embed"] |
| torch.onnx.export(hubert_soft.to(device), |
| test_input.to(device), |
| "hubert3.0.onnx", |
| dynamic_axes={ |
| "source": { |
| 2: "sample_length" |
| } |
| }, |
| verbose=False, |
| opset_version=13, |
| input_names=input_names, |
| output_names=output_names) |
| if(NetExport): |
| device = torch.device("cuda") |
| hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") |
| SVCVITS = SynthesizerTrn( |
| hps.data.filter_length // 2 + 1, |
| hps.train.segment_size // hps.data.hop_length, |
| **hps.model) |
| _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None) |
| _ = SVCVITS.eval().to(device) |
| for i in SVCVITS.parameters(): |
| i.requires_grad = False |
| test_hidden_unit = torch.rand(1, 50, 256) |
| test_lengths = torch.LongTensor([50]) |
| test_pitch = torch.rand(1, 50) |
| test_sid = torch.LongTensor([0]) |
| input_names = ["hidden_unit", "lengths", "pitch", "sid"] |
| output_names = ["audio", ] |
| SVCVITS.eval() |
| torch.onnx.export(SVCVITS, |
| ( |
| test_hidden_unit.to(device), |
| test_lengths.to(device), |
| test_pitch.to(device), |
| test_sid.to(device) |
| ), |
| f"checkpoints/{path}/model.onnx", |
| dynamic_axes={ |
| "hidden_unit": [0, 1], |
| "pitch": [1] |
| }, |
| do_constant_folding=False, |
| opset_version=16, |
| verbose=False, |
| input_names=input_names, |
| output_names=output_names) |
|
|
|
|
if __name__ == "__main__":
    # Script entry point: skip the HuBERT export and export only the
    # synthesizer network.
    main(HubertExport=False, NetExport=True)
|
|