File size: 2,169 Bytes
a4bd2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Convert DeepFormants LPC RNN tracker to ONNX fp32/fp16/int8.

Source: pytorchFormants/Tracker/LPC_RNN.pt
Arch: LSTM(350->512) -> LSTM(512->256) -> Linear(256, 4). Input (B, T, 350).
"""
from pathlib import Path
import torch
import torch.nn as nn
import onnx
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType

# All locations are derived from this script's own path, so the conversion
# does not depend on the current working directory.
ROOT = Path(__file__).resolve().parents[1]

# "DeepFormants" directory located next to ROOT (presumably a checkout of the
# upstream repository that ships the checkpoint -- confirm local layout).
REPO = ROOT.parent.joinpath("DeepFormants")
CKPT = REPO.joinpath("pytorchFormants", "Tracker", "LPC_RNN.pt")

# Destination directory for the exported models; created at import time.
OUT = ROOT.joinpath("lpc_tracker")
OUT.mkdir(exist_ok=True, parents=True)


class LSTMTracker(nn.Module):
    """Two stacked LSTM layers followed by a per-frame linear projection.

    With the default arguments the network is
    LSTM(350->512) -> LSTM(512->256) -> Linear(256, 4).

    The submodule attribute names (``lstm1``, ``lstm2``, ``fc``) determine the
    ``state_dict`` keys, so they must not be renamed if existing checkpoints
    are to keep loading through ``load_state_dict``.

    Args:
        input_size: Number of features per input frame.
        hidden1_size: Hidden size of the first LSTM layer.
        hidden2_size: Hidden size of the second LSTM layer.
        num_outputs: Number of values predicted per frame.
    """

    def __init__(
        self,
        input_size: int = 350,
        hidden1_size: int = 512,
        hidden2_size: int = 256,
        num_outputs: int = 4,
    ) -> None:
        super().__init__()
        # Construction order is kept (lstm1, lstm2, fc) so that seeded random
        # initialisation matches the previous hard-coded version.
        self.lstm1 = nn.LSTM(
            input_size=input_size, hidden_size=hidden1_size, batch_first=True
        )
        self.lstm2 = nn.LSTM(
            input_size=hidden1_size, hidden_size=hidden2_size, batch_first=True
        )
        self.fc = nn.Linear(hidden2_size, num_outputs)

    def forward(self, x):
        """Map a batch-first sequence to per-frame predictions.

        Args:
            x: Tensor of shape (B, T, input_size).

        Returns:
            Tensor of shape (B, T, num_outputs).
        """
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is
        # passed on, the final hidden/cell states are discarded.
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        return self.fc(x)


def _report_size(label, path):
    """Print the "<label> -> <path> (<size> MB)" summary for a written model file."""
    print(f"{label} -> {path} ({path.stat().st_size/1e6:.2f} MB)")


def main():
    """Export the LPC tracker checkpoint to ONNX and derive fp16 / int8 variants.

    Steps:
        1. Build ``LSTMTracker``, load the state dict from ``CKPT`` and switch
           the model to eval mode.
        2. Export to ``OUT / "model.onnx"`` (opset 17) with the batch and time
           axes marked dynamic, then validate the file with the ONNX checker.
        3. Convert the fp32 graph to float16 and save it as
           ``OUT / "model_fp16.onnx"``.
        4. Dynamically quantize the fp32 file with int8 weights into
           ``OUT / "model_int8.onnx"``.

    After each file is written its path and size in MB are printed.
    """
    model = LSTMTracker()
    # weights_only=True restricts unpickling to tensors/primitive containers,
    # so the checkpoint must be a plain state dict.
    state_dict = torch.load(CKPT, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Only the last dimension (350) is fixed; batch and time are declared
    # dynamic below, so the concrete (1, 20) used for tracing is arbitrary.
    # NOTE(review): tracing an LSTM with batch size 1 and a dynamic batch axis
    # may emit a PyTorch exporter warning -- confirm the exported model runs
    # with other batch sizes if that is required.
    dummy = torch.randn(1, 20, 350)
    fp32_path = OUT / "model.onnx"
    torch.onnx.export(
        model, dummy, fp32_path.as_posix(),
        opset_version=17,
        input_names=["input"], output_names=["formants"],
        dynamic_axes={"input": {0: "batch", 1: "time"},
                      "formants": {0: "batch", 1: "time"}},
        dynamo=False,
    )

    # Load the exported graph once and reuse it for both validation and the
    # fp16 conversion instead of deserializing the same file twice.
    m_fp32 = onnx.load(fp32_path.as_posix())
    onnx.checker.check_model(m_fp32)
    _report_size("fp32", fp32_path)

    # keep_io_types=False: graph inputs/outputs are converted to float16 as
    # well, so callers of the fp16 model must feed float16 tensors.
    m_fp16 = float16.convert_float_to_float16(m_fp32, keep_io_types=False)
    fp16_path = OUT / "model_fp16.onnx"
    onnx.save(m_fp16, fp16_path.as_posix())
    _report_size("fp16", fp16_path)

    # Dynamic quantization works from the fp32 file on disk, not from the
    # in-memory proto, so it is unaffected by the fp16 conversion above.
    int8_path = OUT / "model_int8.onnx"
    quantize_dynamic(
        fp32_path.as_posix(), int8_path.as_posix(),
        weight_type=QuantType.QInt8,
    )
    _report_size("int8", int8_path)


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()