# DeepFormants/scripts/convert_lpc.py
# Uploaded by FredrikKarlssonSpeech.
# Commit message: "Upload DeepFormants ONNX (fp32/fp16/int8) for LPC estimator + LSTM tracker"
# Commit a4bd2c0 (verified)
"""Convert DeepFormants LPC MLP estimator to ONNX fp32/fp16/int8.
Source: pytorchFormants/Estimator/LPC_NN_scaledLoss.pt (350 -> 1024 -> 512 -> 256 -> 4, sigmoid hidden).
"""
from pathlib import Path
import torch
import torch.nn as nn
import onnx
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType
# All locations are derived from this script's own path.
# ROOT is the parent of the directory that contains this file. The upstream
# DeepFormants checkout (holding the PyTorch checkpoint) is expected to sit next
# to ROOT, and the exported ONNX files go into ROOT/lpc_estimator.
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent.joinpath("DeepFormants")
OUT = ROOT.joinpath("lpc_estimator")
CKPT = REPO.joinpath("pytorchFormants", "Estimator", "LPC_NN_scaledLoss.pt")
# Created at import time so the exporters can write into it directly.
OUT.mkdir(parents=True, exist_ok=True)
class Net(nn.Module):
    """Feed-forward estimator: 350 -> 1024 -> 512 -> 256 -> 4.

    Every hidden layer is followed by a sigmoid; the final layer is linear.
    The attribute names (``Dense1``..``Dense3``, ``out``) are kept verbatim
    because the checkpoint is restored with ``load_state_dict``, which matches
    parameters by these names.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order is preserved so parameter init and state_dict
        # ordering are unchanged.
        self.Dense1 = nn.Linear(350, 1024)
        self.Dense2 = nn.Linear(1024, 512)
        self.Dense3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs whose last dimension is 350 to 4 outputs."""
        hidden_layers = (self.Dense1, self.Dense2, self.Dense3)
        activations = x
        for layer in hidden_layers:
            activations = torch.sigmoid(layer(activations))
        return self.out(activations)
def _report_size(label: str, path: Path) -> None:
    """Print ``<label> -> <path> (<size> MB)`` for a written model file."""
    print(f"{label} -> {path} ({path.stat().st_size/1e6:.2f} MB)")


def _has_quantized_ops(model_path: Path) -> bool:
    """Return True if the ONNX graph at *model_path* contains integer or
    dynamic-quantize operators (e.g. ``MatMulInteger``, ``DynamicQuantizeLinear``)."""
    op_types = {node.op_type for node in onnx.load(model_path.as_posix()).graph.node}
    return any("Integer" in op or op.startswith("DynamicQuantize") for op in op_types)


def main() -> None:
    """Export the LPC estimator checkpoint to ONNX as fp32, fp16 and int8.

    Reads the state_dict at ``CKPT`` into ``Net`` and writes ``model.onnx``,
    ``model_fp16.onnx`` and ``model_int8.onnx`` into ``OUT``, printing each
    file's size.
    """
    model = Net()
    # weights_only=True refuses arbitrary pickled objects, so the checkpoint
    # must be a plain tensor state_dict.
    state_dict = torch.load(CKPT, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # --- fp32: trace with a single-row dummy; the batch axis is dynamic. ---
    dummy = torch.randn(1, 350)
    fp32_path = OUT / "model.onnx"
    torch.onnx.export(
        model, dummy, fp32_path.as_posix(),
        opset_version=17,
        input_names=["input"], output_names=["formants"],
        dynamic_axes={"input": {0: "batch"}, "formants": {0: "batch"}},
        dynamo=False,
    )
    # Load once and reuse for both validation and fp16 conversion.
    m_fp32 = onnx.load(fp32_path.as_posix())
    onnx.checker.check_model(m_fp32)
    _report_size("fp32", fp32_path)

    # --- fp16: keep_io_types=False also converts the graph inputs/outputs
    # to float16, so consumers of this file must feed float16 tensors. ---
    m_fp16 = float16.convert_float_to_float16(m_fp32, keep_io_types=False)
    onnx.checker.check_model(m_fp16)  # validate before writing, like fp32
    fp16_path = OUT / "model_fp16.onnx"
    onnx.save(m_fp16, fp16_path.as_posix())
    _report_size("fp16", fp16_path)

    # --- int8: dynamic (weight-only) quantization of the fp32 file on disk. ---
    int8_path = OUT / "model_int8.onnx"
    quantize_dynamic(
        fp32_path.as_posix(), int8_path.as_posix(),
        weight_type=QuantType.QInt8,
    )
    _report_size("int8", int8_path)
    # NOTE(review): torch presumably exports nn.Linear on rank-2 input as Gemm,
    # and quantize_dynamic's default operator set may cover MatMul but not
    # Gemm; verify against the installed onnxruntime. Surface a no-op
    # quantization instead of silently shipping an fp32 model named "int8".
    if not _has_quantized_ops(int8_path):
        print(
            f"warning: {int8_path} contains no integer/dynamic-quantize ops; "
            "its weights were probably left in fp32"
        )
if __name__ == "__main__":
    # Run the conversion only when executed as a script, never on import.
    main()