#!/usr/bin/env python3
"""Export Piper checkpoint using JIT tracing"""

import sys
from pathlib import Path

import torch

# The Piper sources are not installed as a package here; make them importable
# before the project import below.
sys.path.insert(0, str(Path("/root/piper_msa/piper1-gpl/src")))

from piper.train.vits.lightning import VitsModel  # noqa: E402

# Locations used by this export run.
CHECKPOINT_PATH = (
    "/root/piper_msa/piper1-gpl/lightning_logs/version_1/checkpoints/"
    "epoch=54-step=143440.ckpt"
)
OUTPUT_PATH = "/root/piper_msa/output/saudi_msa.onnx"
CONFIG_PATH = "/root/piper_msa/output/config.json"

# Number of symbol ids in the random sequence used to trace the model.
DUMMY_INPUT_LENGTH = 50


def _install_infer_forward(model_g):
    """Replace ``model_g.forward`` with a wrapper around ``model_g.infer``."""

    def infer_forward(text, text_lengths, scales, sid=None):
        # ``scales`` packs three values, in this order:
        # noise_scale, length_scale, noise_scale_w.
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        # Keep only the first element returned by ``infer`` and insert a new
        # axis at position 1 (presumably the audio channel axis -- confirm).
        audio = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0].unsqueeze(1)
        return audio

    model_g.forward = infer_forward


def _build_dummy_input(model_g):
    """Build the ``(sequences, lengths, scales, sid)`` tuple used for tracing."""
    num_symbols = model_g.n_vocab
    num_speakers = model_g.n_speakers

    # One random sequence of symbol ids; the values only drive the trace.
    sequences = torch.randint(
        low=0, high=num_symbols, size=(1, DUMMY_INPUT_LENGTH), dtype=torch.long
    )
    sequence_lengths = torch.LongTensor([sequences.size(1)])

    # A speaker id is only supplied when the model has more than one speaker.
    sid = None
    if num_speakers > 1:
        sid = torch.LongTensor([0])

    # Order must match what ``infer_forward`` unpacks:
    # noise_scale, length_scale, noise_scale_w.
    scales = torch.FloatTensor([0.667, 1.0, 0.8])

    return (sequences, sequence_lengths, scales, sid)


def main():
    """Load the checkpoint at CHECKPOINT_PATH and export it to OUTPUT_PATH.

    The generator (``model.model_g``) is put in eval mode, weight norm is
    removed from its decoder, its ``forward`` is redirected to ``infer``, and
    the result is exported with ``torch.onnx.export`` using the legacy
    JIT-based exporter. The output directory is created if it is missing.
    """
    checkpoint_path = CHECKPOINT_PATH
    output_path = OUTPUT_PATH

    print(f"Loading checkpoint: {checkpoint_path}")
    model = VitsModel.load_from_checkpoint(checkpoint_path, map_location="cpu")
    model_g = model.model_g

    # Inference only
    model_g.eval()

    with torch.no_grad():
        model_g.dec.remove_weight_norm()

    _install_infer_forward(model_g)
    dummy_input = _build_dummy_input(model_g)

    # Create the destination directory up front so the export does not fail
    # on a missing directory after the checkpoint has already been loaded.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    print(f"Exporting to ONNX using JIT: {output_path}")

    # Use JIT tracing with legacy exporter
    # NOTE(review): "sid" stays in input_names even when sid is None --
    # confirm the exporter handles the extra name as intended.
    with torch.no_grad():
        torch.onnx.export(
            model=model_g,
            args=dummy_input,
            f=output_path,
            verbose=False,
            opset_version=15,
            input_names=["input", "input_lengths", "scales", "sid"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "phonemes"},
                "input_lengths": {0: "batch_size"},
                "output": {0: "batch_size", 2: "time"},
            },
            export_params=True,
            do_constant_folding=True,
            # Use legacy JIT-based exporter
            dynamo=False,
        )

    print(f"✓ Model exported successfully to: {output_path}")
    print("\nTo test the model:")
    print(
        f" echo 'مرحبا بك' | piper --model {output_path} "
        f"--config {CONFIG_PATH} --output_file test.wav"
    )


if __name__ == "__main__":
    main()