| """Convert DeepFormants LPC MLP estimator to ONNX fp32/fp16/int8. |
| |
| Source: pytorchFormants/Estimator/LPC_NN_scaledLoss.pt (350 -> 1024 -> 512 -> 256 -> 4, sigmoid hidden). |
| """ |
| from pathlib import Path |
| import torch |
| import torch.nn as nn |
| import onnx |
| from onnxconverter_common import float16 |
| from onnxruntime.quantization import quantize_dynamic, QuantType |
|
|
# Project layout: this script lives one directory below ROOT. Generated models
# are written to ROOT/lpc_estimator, and the upstream DeepFormants checkout is
# expected to be a sibling directory of ROOT.
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent.joinpath("DeepFormants")
OUT = ROOT.joinpath("lpc_estimator")
# NOTE: this runs at import time, so merely importing the module creates OUT.
OUT.mkdir(exist_ok=True, parents=True)

# PyTorch checkpoint of the LPC estimator inside the DeepFormants checkout.
CKPT = REPO.joinpath("pytorchFormants", "Estimator", "LPC_NN_scaledLoss.pt")
|
|
|
|
class Net(nn.Module):
    """MLP mapping a 350-dim LPC input vector to 4 outputs.

    Layout: 350 -> 1024 -> 512 -> 256 -> 4. A sigmoid follows each hidden
    layer, and the output layer is linear (no activation). The submodule
    attribute names double as the checkpoint's state-dict keys, so they must
    not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Registration order (Dense1, Dense2, Dense3, out) fixes both the
        # state-dict key order and the order parameters are initialised in.
        self.Dense1 = nn.Linear(350, 1024)
        self.Dense2 = nn.Linear(1024, 512)
        self.Dense3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 4)

    def forward(self, x):
        """Return the linear output for ``x`` whose last dimension is 350."""
        hidden = x
        for layer in (self.Dense1, self.Dense2, self.Dense3):
            hidden = torch.sigmoid(layer(hidden))
        return self.out(hidden)
|
|
|
|
def _report(label, path):
    """Print ``label -> path (size MB)`` for a freshly written model file."""
    print(f"{label} -> {path} ({path.stat().st_size/1e6:.2f} MB)")


def _load_model():
    """Build ``Net``, load the checkpoint into it and switch to eval mode."""
    if not CKPT.is_file():
        raise FileNotFoundError(f"checkpoint not found: {CKPT}")
    model = Net()
    # weights_only=True restricts unpickling to tensors and plain containers;
    # the checkpoint is fed straight to load_state_dict, so it must be a
    # state dict whose keys match Net's attribute names.
    state = torch.load(CKPT, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    model.eval()
    return model


def _export_fp32(model, path):
    """Trace ``model`` to ONNX at ``path`` and return the validated ModelProto."""
    # Net has no data-dependent control flow, so only the dummy input's
    # shape/dtype matter for tracing, not its values.
    dummy = torch.randn(1, 350)
    torch.onnx.export(
        model, dummy, path.as_posix(),
        opset_version=17,
        input_names=["input"], output_names=["formants"],
        # Dynamic batch axis lets callers run any number of rows at once.
        dynamic_axes={"input": {0: "batch"}, "formants": {0: "batch"}},
        dynamo=False,  # use the TorchScript-based exporter
    )
    proto = onnx.load(path.as_posix())
    onnx.checker.check_model(proto)
    return proto


def _export_fp16(fp32_model, path):
    """Convert ``fp32_model`` to float16, validate it and save it to ``path``."""
    # keep_io_types=False: the graph input/output tensors become float16 too,
    # so consumers of this variant must feed float16 data.
    fp16_model = float16.convert_float_to_float16(fp32_model, keep_io_types=False)
    onnx.checker.check_model(fp16_model)
    onnx.save(fp16_model, path.as_posix())


def _export_int8(fp32_path, path):
    """Dynamically quantize the fp32 ONNX file's weights to int8 at ``path``."""
    quantize_dynamic(
        fp32_path.as_posix(), path.as_posix(),
        weight_type=QuantType.QInt8,
    )


def main():
    """Export the LPC estimator under OUT as fp32, fp16 and int8 ONNX models.

    Raises:
        FileNotFoundError: if the PyTorch checkpoint ``CKPT`` does not exist.
    """
    import sys

    # Idempotent; guarantees the output directory exists before writing.
    OUT.mkdir(parents=True, exist_ok=True)
    model = _load_model()

    fp32_path = OUT / "model.onnx"
    fp32_model = _export_fp32(model, fp32_path)
    _report("fp32", fp32_path)

    # Reuse the already-loaded, checked proto instead of re-reading the file.
    fp16_path = OUT / "model_fp16.onnx"
    _export_fp16(fp32_model, fp16_path)
    _report("fp16", fp16_path)

    int8_path = OUT / "model_int8.onnx"
    _export_int8(fp32_path, int8_path)
    _report("int8", int8_path)

    # int8 weights should make the file roughly 4x smaller than fp32. If not,
    # the layers were probably left unquantized. NOTE(review): the TorchScript
    # exporter emits nn.Linear on 2-D input as Gemm, which ORT's dynamic
    # quantizer may not handle -- confirm against the installed ORT version.
    if int8_path.stat().st_size > fp32_path.stat().st_size / 2:
        print(
            "warning: int8 model is not much smaller than fp32; "
            "weights may not have been quantized",
            file=sys.stderr,
        )
|
|
|
|
if __name__ == "__main__":
    # Run the export pipeline only when executed as a script, not on import.
    # main() returns None, so a successful run exits with status 0.
    raise SystemExit(main())
|
|