| import argparse |
| import os |
| import torch |
| import onnx |
| import onnxruntime as ort |
| import sounddevice as sd |
|
|
| from indicvoice import IndicModel, IndicPipeline |
| from indicvoice.model import IndicModelForONNX |
|
|
def export_onnx(model, output):
    """Export ``model`` to ``indicvoice.onnx`` inside ``output`` and check it.

    The model is exported with randomly generated dummy inputs using opset 17.
    The written file is then reloaded and passed through the ONNX checker.

    Args:
        model: Module handed to ``torch.onnx.export``. It is exported with the
            inputs ``(input_ids, style, speed)`` and its outputs are named
            ``waveform`` and ``duration``.
        output: Directory in which ``indicvoice.onnx`` is written.
    """
    # Build the path with os.path.join rather than manual "/" concatenation.
    onnx_file = os.path.join(output, "indicvoice.onnx")

    # Dummy inputs: 48 random token ids in [1, 100) wrapped by a 0 on each
    # side, a random (1, 256) style tensor and a random int32 speed in [1, 10).
    token_ids = torch.randint(1, 100, (48,)).tolist()
    input_ids = torch.LongTensor([[0, *token_ids, 0]])
    style = torch.randn(1, 256)
    speed = torch.randint(1, 10, (1,)).int()

    torch.onnx.export(
        model,
        args=(input_ids, style, speed),
        f=onnx_file,
        export_params=True,
        verbose=True,
        input_names=['input_ids', 'style', 'speed'],
        output_names=['waveform', 'duration'],
        opset_version=17,
        # Only input axes are declared dynamic: the batch axis of every input
        # and the sequence-length axis of input_ids.
        dynamic_axes={
            'input_ids': {0: "batch_size", 1: 'input_ids_len'},
            'style': {0: "batch_size"},
            "speed": {0: "batch_size"},
        },
        do_constant_folding=True,
    )
    print('export indicvoice.onnx ok!')

    # Reload the exported file and validate it.
    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)
    print('onnx check ok!')
|
|
def load_input_ids(pipeline, text):
    """Phonemize ``text`` and map the phonemes to a tensor of token ids.

    Args:
        pipeline: Object providing ``lang_code``, ``g2p`` and ``model`` (with
            ``vocab`` and ``device``). ``en_tokenize`` is also required when
            ``lang_code`` is matched by the ``in 'ab'`` test.
        text: Text to phonemize.

    Returns:
        A tuple ``(ps, input_ids)``: the phonemes (at most 510 of them) and a
        ``LongTensor`` holding a single row ``[0, *ids, 0]`` placed on
        ``pipeline.model.device``. Phonemes missing from the vocabulary
        contribute no id.
    """
    max_phonemes = 510

    if pipeline.lang_code in 'ab':
        _, tokens = pipeline.g2p(text)
        # Keep the last non-empty phoneme chunk. Starting from an empty value
        # means a tokenizer that yields nothing no longer leaves ``ps``
        # unbound, and an empty trailing chunk can no longer overwrite usable
        # phonemes.
        # NOTE(review): assumes phonemes are strings, hence the '' default.
        ps = ''
        for _, chunk_ps, _ in pipeline.en_tokenize(tokens):
            if chunk_ps:
                ps = chunk_ps
    else:
        ps, _ = pipeline.g2p(text)

    # Truncate to the first 510 phonemes.
    ps = ps[:max_phonemes]

    vocab = pipeline.model.vocab
    input_ids = [token_id for token_id in map(vocab.get, ps) if token_id is not None]
    print(f"text: {text} -> phonemes: {ps} -> input_ids: {input_ids}")
    input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(pipeline.model.device)
    return ps, input_ids
|
|
def load_voice(pipeline, voice, phonemes):
    """Pick the entry of a voice pack that corresponds to the phoneme count.

    Args:
        pipeline: Object whose ``load_voice(voice)`` returns the voice pack.
        voice: Voice identifier or path forwarded to ``pipeline.load_voice``.
        phonemes: Sized phoneme sequence; only its length is used.

    Returns:
        The element at index ``len(phonemes) - 1`` of the voice pack after it
        has been moved to the CPU.
    """
    voice_pack = pipeline.load_voice(voice)
    cpu_pack = voice_pack.to('cpu')
    style_index = len(phonemes) - 1
    return cpu_pack[style_index]
|
|
def load_sample(model, lang_code='z'):
    """Build one set of sample inputs ``(input_ids, style, speed)``.

    Only one pipeline is created, for the requested language. The default
    ``'z'`` gives the same sample as before: the Chinese text with the
    ``zf_xiaoxiao`` voice.

    Args:
        model: Wrapper exposing the underlying model as ``model.kmodel``.
        lang_code: ``'z'`` for the Chinese sample (default) or ``'a'`` for the
            English sample.

    Returns:
        A tuple ``(input_ids, style, speed)`` where ``input_ids`` comes from
        ``load_input_ids``, ``style`` from ``load_voice`` and ``speed`` is
        ``torch.IntTensor([1])``.

    Raises:
        ValueError: If ``lang_code`` is neither ``'a'`` nor ``'z'``.
    """
    if lang_code == 'a':
        text = '''
    The sky above the port was the color of television, tuned to a dead channel.
    '''
        voice = 'checkpoints/voices/af_heart.pt'
    elif lang_code == 'z':
        text = '''
    2月15日晚,猫眼专业版数据显示,截至发稿,《哪吒之魔童闹海》(或称《哪吒2》)今日票房已达7.8亿元,累计票房(含预售)超过114亿元。
    '''
        voice = 'checkpoints/voices/zf_xiaoxiao.pt'
    else:
        raise ValueError(f"unsupported sample lang_code: {lang_code!r}")

    pipeline = IndicPipeline(lang_code=lang_code, model=model.kmodel, device='cpu')

    phonemes, input_ids = load_input_ids(pipeline, text)
    style = load_voice(pipeline, voice, phonemes)
    speed = torch.IntTensor([1])

    return input_ids, style, speed
|
|
def inference_onnx(model, output):
    """Run the exported ONNX model on the sample input and play the result.

    Args:
        model: Wrapper passed to ``load_sample`` to build the sample inputs.
        output: Directory that contains ``indicvoice.onnx``.
    """
    # Build the path with os.path.join rather than manual "/" concatenation.
    onnx_file = os.path.join(output, "indicvoice.onnx")
    session = ort.InferenceSession(onnx_file)

    input_ids, style, speed = load_sample(model)

    # Passing None as the first argument requests all model outputs.
    outputs = session.run(None, {
        'input_ids': input_ids.numpy(),
        'style': style.numpy(),
        'speed': speed.numpy(),
    })

    # Use a separate name so the ``output`` directory argument is not rebound
    # to a tensor.
    waveform = torch.from_numpy(outputs[0])
    print(f'output: {waveform.shape}')
    print(waveform)

    # Play the first model output with 24000 as the sample rate and block
    # until playback has finished.
    audio = waveform.numpy()
    sd.play(audio, 24000)
    sd.wait()
|
|
def check_model(model):
    """Run the PyTorch model on the sample input and play the result.

    Args:
        model: Callable model invoked as ``model(input_ids, style, speed)``
            and returning ``(output, duration)``. It is also passed to
            ``load_sample`` to build the inputs.
    """
    input_ids, style, speed = load_sample(model)

    # Run without gradient tracking so the result can be converted to NumPy.
    with torch.no_grad():
        output, duration = model(input_ids, style, speed)

    print(f'output: {output.shape}')
    print(f'duration: {duration.shape}')
    print(output)

    # detach()/cpu() are no-ops for a plain CPU tensor; they guard against
    # outputs that still carry autograd state or live on another device.
    audio = output.detach().cpu().numpy()
    sd.play(audio, 24000)
    sd.wait()
|
|
def build_arg_parser():
    """Create the command-line parser for this script.

    The title is passed as ``description``. Passing it positionally would set
    ``prog`` and put the title into the usage line.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Export IndicVoice Model to ONNX", add_help=True
    )
    parser.add_argument("--inference", "-t", help="test indicvoice.onnx model", action="store_true")
    parser.add_argument("--check", "-m", help="check indicvoice model", action="store_true")
    parser.add_argument(
        "--config_file", "-c", type=str, default="checkpoints/config.json", help="path to config file"
    )
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, default="checkpoints/indicvoice-v1_0.pth", help="path to checkpoint file"
    )
    parser.add_argument(
        "--output_dir", "-o", type=str, default="onnx", help="output directory"
    )
    return parser


def main():
    """Parse the arguments, load the model and run the selected action.

    ``--inference`` takes precedence over ``--check``; with neither flag the
    model is exported to ONNX.
    """
    args = build_arg_parser().parse_args()

    config_file = args.config_file
    checkpoint_path = args.checkpoint_path
    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    kmodel = IndicModel(config=config_file, model=checkpoint_path, disable_complex=True)
    model = IndicModelForONNX(kmodel).eval()

    if args.inference:
        inference_onnx(model, output_dir)
    elif args.check:
        check_model(model)
    else:
        export_onnx(model, output_dir)


if __name__ == "__main__":
    main()
|
|