# DeepFormants / scripts / convert_tracker.py
# FredrikKarlssonSpeech's picture
# Upload DeepFormants ONNX (fp32/fp16/int8) for LPC estimator + LSTM tracker
# a4bd2c0 verified
"""Convert the DeepFormants LPC RNN tracker to ONNX (fp32, fp16 and int8).

Source checkpoint: pytorchFormants/Tracker/LPC_RNN.pt
Architecture: LSTM(350->512) -> LSTM(512->256) -> Linear(256, 4).
Input (B, T, 350) -> output (B, T, 4).
"""
from pathlib import Path
import torch
import torch.nn as nn
import onnx
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType
# Directory layout, derived from this script's own location:
#   ROOT - the parent of the directory that contains this script
#   REPO - a sibling directory of ROOT named "DeepFormants"
#   OUT  - destination for the exported ONNX files (created at import time)
#   CKPT - the PyTorch tracker weights inside REPO
ROOT = Path(__file__).resolve().parent.parent
REPO = ROOT.parent.joinpath("DeepFormants")
OUT = ROOT.joinpath("lpc_tracker")
OUT.mkdir(exist_ok=True, parents=True)
CKPT = REPO.joinpath("pytorchFormants", "Tracker", "LPC_RNN.pt")
class LSTMTracker(nn.Module):
    """Two stacked batch-first LSTMs followed by a linear head.

    Maps a sequence of 350-dimensional frames to 4 values per frame:
    (B, T, 350) -> (B, T, 4). The submodule names ``lstm1``, ``lstm2`` and
    ``fc`` must not change. They are the state-dict key prefixes that the
    ``LPC_RNN.pt`` weights are (strictly) loaded against.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(350, 512, batch_first=True)
        self.lstm2 = nn.LSTM(512, 256, batch_first=True)
        self.fc = nn.Linear(in_features=256, out_features=4)

    def forward(self, x):
        """Feed *x* through both LSTMs (dropping their final states), then project."""
        hidden = x
        for rnn in (self.lstm1, self.lstm2):
            # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is kept.
            hidden = rnn(hidden)[0]
        return self.fc(hidden)
def _report(label, path):
    """Print *label*, the artifact *path* and its on-disk size in MB."""
    print(f"{label} -> {path} ({path.stat().st_size/1e6:.2f} MB)")


def main():
    """Export the LPC LSTM tracker as fp32, fp16 and int8 ONNX models.

    Loads the weights from ``CKPT`` into ``LSTMTracker``, then writes
    ``model.onnx`` (fp32), ``model_fp16.onnx`` and ``model_int8.onnx`` into
    ``OUT``. Every artifact is validated with ``onnx.checker.check_model``
    and its size is printed.
    """
    model = LSTMTracker()
    sd = torch.load(CKPT, map_location="cpu", weights_only=True)
    # Strict load: checkpoint keys must match lstm1.*, lstm2.* and fc.*.
    model.load_state_dict(sd)
    model.eval()

    # Only the shape of the dummy input matters for tracing. Batch and time
    # are declared dynamic below. NOTE: batch=1 is deliberate. torch's
    # exporter warns that LSTMs exported with another batch size can fail at
    # different batch sizes unless h0/c0 are model inputs.
    dummy = torch.randn(1, 20, 350)

    fp32_path = OUT / "model.onnx"
    torch.onnx.export(
        model, dummy, fp32_path.as_posix(),
        opset_version=17,
        input_names=["input"], output_names=["formants"],
        dynamic_axes={"input": {0: "batch", 1: "time"},
                      "formants": {0: "batch", 1: "time"}},
        dynamo=False,  # use the TorchScript-based exporter
    )
    # Load the exported graph once; it is both validated and converted to fp16.
    m_fp32 = onnx.load(fp32_path.as_posix())
    onnx.checker.check_model(m_fp32)
    _report("fp32", fp32_path)

    # keep_io_types=False: graph inputs/outputs are converted to float16 as
    # well, so this variant must be fed float16 tensors.
    m_fp16 = float16.convert_float_to_float16(m_fp32, keep_io_types=False)
    onnx.checker.check_model(m_fp16)
    fp16_path = OUT / "model_fp16.onnx"
    onnx.save(m_fp16, fp16_path.as_posix())
    _report("fp16", fp16_path)

    # Dynamic quantization works from the fp32 file on disk.
    int8_path = OUT / "model_int8.onnx"
    quantize_dynamic(
        fp32_path.as_posix(), int8_path.as_posix(),
        weight_type=QuantType.QInt8,
    )
    onnx.checker.check_model(int8_path.as_posix())
    _report("int8", int8_path)
if __name__ == "__main__":
    # Convert only when run as a script, never as a side effect of import.
    main()