| | |
| | """Export Piper checkpoint using JIT tracing""" |
| |
|
| | import sys |
| | import torch |
| | from pathlib import Path |
| |
|
# Make the local piper1-gpl checkout importable ahead of any installed copy,
# so the `from piper...` import below resolves to this source tree.
_PIPER_SRC = Path("/root/piper_msa/piper1-gpl/src")
sys.path.insert(0, str(_PIPER_SRC))
| |
|
| | from piper.train.vits.lightning import VitsModel |
| |
|
def main(
    checkpoint_path="/root/piper_msa/piper1-gpl/lightning_logs/version_1/checkpoints/epoch=54-step=143440.ckpt",
    output_path="/root/piper_msa/output/saudi_msa.onnx",
    config_path="/root/piper_msa/output/config.json",
):
    """Export a Piper VITS generator checkpoint to ONNX using JIT tracing.

    The defaults reproduce the previously hard-coded paths, so calling
    ``main()`` with no arguments behaves as before.

    Args:
        checkpoint_path: Lightning checkpoint loaded via
            ``VitsModel.load_from_checkpoint`` onto the CPU.
        output_path: Destination ``.onnx`` file. Missing parent directories
            are created before export.
        config_path: Voice config path; only used in the example ``piper``
            command printed after a successful export.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    model = VitsModel.load_from_checkpoint(checkpoint_path, map_location="cpu")
    model_g = model.model_g

    # Put the generator in eval mode before tracing.
    model_g.eval()

    with torch.no_grad():
        # Strip weight normalization from the decoder before tracing
        # (helper provided by the project's decoder module).
        model_g.dec.remove_weight_norm()

    def infer_forward(text, text_lengths, scales, sid=None):
        """Traced entry point: run ``model_g.infer`` with packed scales.

        ``scales`` holds [noise_scale, length_scale, noise_scale_w] so all
        three can be supplied as a single ONNX input.
        """
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0].unsqueeze(1)
        return audio

    # Route module calls (and therefore tracing) through infer_forward
    # instead of the module's default forward.
    model_g.forward = infer_forward

    num_symbols = model_g.n_vocab
    num_speakers = model_g.n_speakers

    # Dummy phoneme ids used only to trace the graph; the batch and phoneme
    # axes are declared dynamic in the export call below.
    dummy_input_length = 50
    sequences = torch.randint(
        low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
    )
    sequence_lengths = torch.LongTensor([sequences.size(1)])
    scales = torch.FloatTensor([0.667, 1.0, 0.8])  # noise, length, noise_w

    dummy_input = (sequences, sequence_lengths, scales)
    input_names = ["input", "input_lengths", "scales"]
    if num_speakers > 1:
        # Only declare "sid" when a speaker id is actually passed, so that
        # input_names lines up one-to-one with the traced inputs. Single-
        # speaker models fall back to infer_forward's sid=None default.
        dummy_input += (torch.LongTensor([0]),)
        input_names.append("sid")

    # Make sure the destination directory exists before writing the model.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    print(f"Exporting to ONNX using JIT: {output_path}")

    with torch.no_grad():
        torch.onnx.export(
            model=model_g,
            args=dummy_input,
            f=output_path,
            verbose=False,
            opset_version=15,
            input_names=input_names,
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "phonemes"},
                "input_lengths": {0: "batch_size"},
                # NOTE(review): confirm axis 2 is really the time axis of
                # infer(...)[0].unsqueeze(1); the output rank depends on
                # model_g.infer, which is not visible in this script.
                "output": {0: "batch_size", 2: "time"},
            },
            export_params=True,
            do_constant_folding=True,
            # Use the TorchScript (tracing) exporter rather than dynamo.
            dynamo=False,
        )

    print(f"✓ Model exported successfully to: {output_path}")
    print("\nTo test the model:")
    print(
        f" echo 'مرحبا بك' | piper --model {output_path} "
        f"--config {config_path} --output_file test.wav"
    )
| |
|
# Run the export with the built-in default paths only when this file is
# executed as a script, never on import.
if __name__ == "__main__":
    main()
| |
|