| """Convert DeepFormants LPC RNN tracker to ONNX fp32/fp16/int8. |
| |
| Source: pytorchFormants/Tracker/LPC_RNN.pt |
| Arch: LSTM(350->512) -> LSTM(512->256) -> Linear(256, 4). Input (B, T, 350). |
| """ |
| from pathlib import Path |
| import torch |
| import torch.nn as nn |
| import onnx |
| from onnxconverter_common import float16 |
| from onnxruntime.quantization import quantize_dynamic, QuantType |
|
|
# Directory layout: this script sits in a subdirectory of ROOT, and the
# upstream DeepFormants checkout is expected to live next to ROOT.
ROOT = Path(__file__).resolve().parents[1]
REPO = ROOT.parent / "DeepFormants"
CKPT = REPO.joinpath("pytorchFormants", "Tracker", "LPC_RNN.pt")

# Destination for every exported ONNX file; created eagerly at import time.
OUT = ROOT.joinpath("lpc_tracker")
OUT.mkdir(exist_ok=True, parents=True)
|
|
|
|
class LSTMTracker(nn.Module):
    """Two stacked unidirectional LSTMs followed by a linear read-out.

    Layer sizes mirror the LPC_RNN.pt checkpoint: 350 -> 512 -> 256 -> 4.
    The attribute names ``lstm1``, ``lstm2`` and ``fc`` must not change:
    ``load_state_dict`` matches the checkpoint keys against them.
    """

    def __init__(self):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, time, features).
        self.lstm1 = nn.LSTM(input_size=350, hidden_size=512, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=512, hidden_size=256, batch_first=True)
        self.fc = nn.Linear(256, 4)

    def forward(self, x):
        """Map a (batch, time, 350) sequence to a (batch, time, 4) output."""
        hidden = x
        for rnn in (self.lstm1, self.lstm2):
            # Keep only the per-step outputs; the final (h_n, c_n) is unused.
            hidden = rnn(hidden)[0]
        return self.fc(hidden)
|
|
|
|
def _report(label, path):
    """Print where an artifact was written and its size in megabytes."""
    print(f"{label} -> {path} ({path.stat().st_size/1e6:.2f} MB)")


def _load_model():
    """Build LSTMTracker, load the LPC_RNN.pt weights and switch to eval mode."""
    model = LSTMTracker()
    # weights_only=True restricts unpickling to tensors and plain containers
    # instead of arbitrary pickled objects.
    state = torch.load(CKPT, map_location="cpu", weights_only=True)
    # Strict (default) loading: any missing/unexpected key raises.
    model.load_state_dict(state)
    model.eval()
    return model


def _export_fp32(model, path):
    """Export *model* to ONNX at *path*; return the loaded, checked ModelProto."""
    # The example input fixes rank and feature size (350); batch and time are
    # declared as dynamic axes below.
    dummy = torch.randn(1, 20, 350)
    torch.onnx.export(
        model, dummy, path.as_posix(),
        opset_version=17,
        input_names=["input"], output_names=["formants"],
        dynamic_axes={"input": {0: "batch", 1: "time"},
                      "formants": {0: "batch", 1: "time"}},
        dynamo=False,  # use the TorchScript-based exporter
    )
    # Load once: the same proto is validated here and reused for fp16.
    proto = onnx.load(path.as_posix())
    onnx.checker.check_model(proto)
    return proto


def _convert_fp16(proto, path):
    """Write a float16 copy of the ONNX *proto* to *path*."""
    # keep_io_types=False also casts the graph inputs/outputs to float16, so
    # callers of this model must feed float16 tensors.
    # NOTE(review): the LSTM nodes are converted as well; confirm the target
    # onnxruntime execution provider has a float16 LSTM kernel.
    fp16 = float16.convert_float_to_float16(proto, keep_io_types=False)
    onnx.save(fp16, path.as_posix())


def _quantize_int8(src, dst):
    """Dynamically quantize the weights of the ONNX file *src* to int8 at *dst*."""
    quantize_dynamic(
        src.as_posix(), dst.as_posix(),
        weight_type=QuantType.QInt8,
    )


def main():
    """Export the tracker as fp32, fp16 and int8 ONNX models into OUT."""
    # Idempotent; keeps main() self-sufficient even though the directory is
    # also created when the module is imported.
    OUT.mkdir(parents=True, exist_ok=True)
    model = _load_model()

    fp32_path = OUT / "model.onnx"
    fp32_proto = _export_fp32(model, fp32_path)
    _report("fp32", fp32_path)

    fp16_path = OUT / "model_fp16.onnx"
    _convert_fp16(fp32_proto, fp16_path)
    _report("fp16", fp16_path)

    int8_path = OUT / "model_int8.onnx"
    _quantize_int8(fp32_path, int8_path)
    _report("int8", int8_path)
|
|
|
|
# Run the conversion only when executed as a script. Importing the module
# still creates the OUT directory, but does not export anything.
if __name__ == "__main__":
    main()
|
|