"""Export the reconstructed Torch7-origin PyTorch weights to ONNX.

Running this script reads ``reconstructed.pt`` (a state_dict) from the
``lpc_estimator_torch7`` directory and writes three models next to it:

* ``model.onnx``      -- fp32 export
* ``model_fp16.onnx`` -- fp32 export converted to float16 (I/O included)
* ``model_int8.onnx`` -- fp32 export with dynamically quantized int8 weights
"""

from pathlib import Path

import torch
import torch.nn as nn
import onnx
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType

# Project root is one directory above the one that holds this script.
ROOT = Path(__file__).resolve().parents[1]
OUT = ROOT / "lpc_estimator_torch7"
# NOTE: runs at import time, so importing this module creates the directory.
OUT.mkdir(parents=True, exist_ok=True)
CKPT = OUT / "reconstructed.pt"


class Net(nn.Module):
    """Fully connected network: 350 -> 1024 -> 512 -> 256 -> 4.

    Each of the three hidden layers is followed by a sigmoid; the final
    4-unit layer is linear. The attribute names (``Dense1`` .. ``Dense3``,
    ``out``) are the state_dict key prefixes, so they must stay in sync with
    the checkpoint loaded in :func:`main`.
    """

    def __init__(self) -> None:
        super().__init__()
        self.Dense1 = nn.Linear(350, 1024)
        self.Dense2 = nn.Linear(1024, 512)
        self.Dense3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three sigmoid-activated hidden layers, then the linear head."""
        hidden = x
        for layer in (self.Dense1, self.Dense2, self.Dense3):
            hidden = torch.sigmoid(layer(hidden))
        return self.out(hidden)


def _report(tag: str, path: Path) -> None:
    """Print the written model's path and its size in MB (bytes / 1e6)."""
    print(f"{tag} -> {path} ({path.stat().st_size/1e6:.2f} MB)")


def main() -> None:
    """Load the checkpoint, export to ONNX fp32, then derive fp16 and int8 models.

    Each stage prints the path and on-disk size of the file it produced.
    The fp32 export is validated with ``onnx.checker`` before the derived
    models are built from it.
    """
    fp32_path = OUT / "model.onnx"
    fp16_path = OUT / "model_fp16.onnx"
    int8_path = OUT / "model_int8.onnx"
    fp32_file = fp32_path.as_posix()

    # Build the network and restore the weights on CPU; weights_only=True
    # restricts unpickling to tensors and plain containers.
    net = Net()
    weights = torch.load(CKPT, map_location="cpu", weights_only=True)
    net.load_state_dict(weights)
    net.eval()

    # A single random example of width 350 drives the export; the leading
    # axis of both input and output is marked dynamic under the name "batch".
    example_input = torch.randn(1, 350)
    torch.onnx.export(
        net,
        example_input,
        fp32_file,
        opset_version=17,
        input_names=["input"],
        output_names=["formants"],
        dynamic_axes={name: {0: "batch"} for name in ("input", "formants")},
        dynamo=False,
    )
    onnx.checker.check_model(onnx.load(fp32_file))
    _report("fp32", fp32_path)

    # fp16: convert a fresh copy of the exported graph. keep_io_types=False
    # means the graph inputs and outputs become float16 as well.
    half_model = float16.convert_float_to_float16(
        onnx.load(fp32_file),
        keep_io_types=False,
    )
    onnx.save(half_model, fp16_path.as_posix())
    _report("fp16", fp16_path)

    # int8: dynamic quantization of the fp32 file with signed 8-bit weights.
    quantize_dynamic(
        fp32_file,
        int8_path.as_posix(),
        weight_type=QuantType.QInt8,
    )
    _report("int8", int8_path)


if __name__ == "__main__":
    main()