File size: 2,144 Bytes
a4bd2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Convert DeepFormants LPC MLP estimator to ONNX fp32/fp16/int8.

Source: pytorchFormants/Estimator/LPC_NN_scaledLoss.pt (350 -> 1024 -> 512 -> 256 -> 4, sigmoid hidden).
"""
from pathlib import Path
import torch
import torch.nn as nn
import onnx
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType

# ROOT is the parent of the directory that contains this script; the upstream
# DeepFormants checkout is expected to live beside ROOT (as a sibling of it).
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent.joinpath("DeepFormants")

# All exported ONNX artifacts go here. NOTE: the directory is created at import
# time, not inside main().
OUT = ROOT.joinpath("lpc_estimator")
OUT.mkdir(exist_ok=True, parents=True)

# Pretrained state dict for the LPC estimator inside the DeepFormants checkout.
CKPT = REPO.joinpath("pytorchFormants", "Estimator", "LPC_NN_scaledLoss.pt")


class Net(nn.Module):
    """Fully connected regressor: 350 -> 1024 -> 512 -> 256 -> 4.

    The three hidden layers are each followed by a sigmoid; the final ``out``
    layer is linear (no activation). The attribute names ``Dense1``,
    ``Dense2``, ``Dense3`` and ``out`` become the state-dict keys
    (``Dense1.weight`` etc.), so renaming them would break loading saved
    weights with ``load_state_dict``.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Layers are created in this exact order because the default random
        # initialisation of each nn.Linear draws from the global RNG in turn.
        self.Dense1 = nn.Linear(350, 1024)
        self.Dense2 = nn.Linear(1024, 512)
        self.Dense3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 4)

    def forward(self, x):
        """Map a (..., 350) tensor to (..., 4) raw (un-activated) outputs."""
        hidden = x
        for layer in (self.Dense1, self.Dense2, self.Dense3):
            hidden = torch.sigmoid(layer(hidden))
        return self.out(hidden)


def _report(label, path):
    """Print ``<label> -> <path> (<size> MB)`` for a file that was just written."""
    print(f"{label} -> {path} ({path.stat().st_size/1e6:.2f} MB)")


def _load_model():
    """Build ``Net``, load the weights from ``CKPT`` on CPU, and switch to eval mode.

    ``weights_only=True`` restricts ``torch.load`` to tensors and plain
    containers instead of arbitrary pickled objects.
    """
    model = Net()
    state_dict = torch.load(CKPT, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _export_fp32(model, path):
    """Export ``model`` to ONNX at ``path``, validate it, and return the loaded proto.

    The graph input is named ``input`` and the output ``formants``; axis 0 of
    both is dynamic (``batch``). The dummy tensor only fixes the feature width
    (350) for tracing.
    """
    dummy = torch.randn(1, 350)
    torch.onnx.export(
        model, dummy, path.as_posix(),
        opset_version=17,
        input_names=["input"], output_names=["formants"],
        dynamic_axes={"input": {0: "batch"}, "formants": {0: "batch"}},
        dynamo=False,
    )
    # Load once and reuse the proto for the fp16 conversion instead of reading
    # the same file from disk a second time.
    proto = onnx.load(path.as_posix())
    onnx.checker.check_model(proto)
    return proto


def _convert_fp16(m_fp32, path):
    """Write a validated float16 copy of ``m_fp32`` to ``path``.

    ``keep_io_types=False`` converts the graph inputs/outputs too, so callers
    of the fp16 model must feed float16 tensors. The converter may modify
    ``m_fp32`` in place; callers should not reuse it afterwards.
    """
    m_fp16 = float16.convert_float_to_float16(m_fp32, keep_io_types=False)
    onnx.checker.check_model(m_fp16)
    onnx.save(m_fp16, path.as_posix())


def _quantize_int8(src, dst):
    """Dynamically quantize the fp32 ONNX model at ``src`` (int8 weights) into ``dst``.

    Prints a warning if the result contains no int8/uint8 initializers, i.e.
    nothing was actually quantized.
    """
    quantize_dynamic(
        src.as_posix(), dst.as_posix(),
        weight_type=QuantType.QInt8,
    )
    # NOTE(review): quantize_dynamic's integer-op registry may not include Gemm,
    # which the TorchScript exporter typically emits for nn.Linear on 2-D input;
    # in that case the "int8" file is still all-float. Confirm against the
    # installed onnxruntime version.
    m_int8 = onnx.load(dst.as_posix())
    int8_types = {onnx.TensorProto.INT8, onnx.TensorProto.UINT8}
    if not any(init.data_type in int8_types for init in m_int8.graph.initializer):
        print(f"warning: {dst} contains no int8 initializers; "
              "quantize_dynamic left every layer in floating point")


def main():
    """Export the estimator as fp32, fp16 and int8 ONNX models under ``OUT``."""
    model = _load_model()

    fp32_path = OUT / "model.onnx"
    m_fp32 = _export_fp32(model, fp32_path)
    _report("fp32", fp32_path)

    fp16_path = OUT / "model_fp16.onnx"
    _convert_fp16(m_fp32, fp16_path)
    _report("fp16", fp16_path)

    # Quantization reads the fp32 model from disk, not the in-memory proto.
    int8_path = OUT / "model_int8.onnx"
    _quantize_int8(fp32_path, int8_path)
    _report("int8", int8_path)


if __name__ == "__main__":
    main()