"""Convert the reconstructed Torch7-origin tracker to ONNX fp32/fp16/int8.

Three artifacts are written into ``OUT`` (the checkpoint's directory):

* ``model.onnx``      -- fp32 export of ``LSTMTracker``
* ``model_fp16.onnx`` -- the fp32 graph converted to fp16, including I/O types
* ``model_int8.onnx`` -- dynamic int8 weight quantization of the fp32 graph
"""
from pathlib import Path

import torch
import torch.nn as nn
import onnx
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType

ROOT = Path(__file__).resolve().parents[1]
OUT = ROOT / "lpc_tracker_torch7"
CKPT = OUT / "reconstructed.pt"


class LSTMTracker(nn.Module):
    """Two stacked batch-first LSTMs followed by a per-step linear head.

    Input is ``(batch, time, 350)``; output is ``(batch, time, 4)``.
    The submodule names ``lstm1``, ``lstm2`` and ``fc`` determine the
    state-dict keys, so they must match those stored in the checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size=350, hidden_size=512, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=512, hidden_size=256, batch_first=True)
        self.fc = nn.Linear(in_features=256, out_features=4)

    def forward(self, x):
        # Hidden/cell states are discarded; each LSTM starts from its
        # default (zero) initial state on every call.
        hidden, _ = self.lstm1(x)
        hidden, _ = self.lstm2(hidden)
        return self.fc(hidden)


def _report(label, path):
    """Print ``label -> path (size MB)`` for a freshly written artifact."""
    print(f"{label} -> {path} ({path.stat().st_size / 1e6:.2f} MB)")


def main():
    """Load the checkpoint and write the fp32, fp16 and int8 ONNX models."""
    model = LSTMTracker()
    state = torch.load(CKPT, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    model.eval()

    # Example input used only for tracing; batch (dim 0) and time (dim 1)
    # are declared dynamic in the export below.
    example = torch.randn(1, 20, 350)

    fp32_path = OUT / "model.onnx"
    fp16_path = OUT / "model_fp16.onnx"
    int8_path = OUT / "model_int8.onnx"

    # dynamo=False selects the TorchScript-based exporter.
    torch.onnx.export(
        model,
        example,
        fp32_path.as_posix(),
        opset_version=17,
        input_names=["input"],
        output_names=["formants"],
        dynamic_axes={
            name: {0: "batch", 1: "time"} for name in ("input", "formants")
        },
        dynamo=False,
    )
    onnx.checker.check_model(onnx.load(fp32_path.as_posix()))
    _report("fp32", fp32_path)

    # Convert a freshly loaded copy of the fp32 graph; keep_io_types=False
    # makes the model's inputs and outputs fp16 as well.
    half_model = float16.convert_float_to_float16(
        onnx.load(fp32_path.as_posix()), keep_io_types=False
    )
    onnx.save(half_model, fp16_path.as_posix())
    _report("fp16", fp16_path)

    # Weight-only dynamic quantization, read directly from the fp32 file.
    quantize_dynamic(
        fp32_path.as_posix(), int8_path.as_posix(), weight_type=QuantType.QInt8
    )
    _report("int8", int8_path)


if __name__ == "__main__":
    main()