| import os |
| import torch |
| import torch.onnx |
| import onnx |
| import onnxruntime as ort |
| from onnxruntime.quantization import quantize_dynamic, QuantType |
|
|
| from crnn import CRNN |
|
|
|
|
# Character set the CRNN recognizes (A-Z then 0-9); presumably the extra
# output class at index len(LETTERS) is the CTC blank, since the model is
# built with len(LETTERS) + 1 classes below -- confirm against crnn.CRNN.
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
|
|
def export_onnx(
    model_path: str,
    onnx_path: str,
    quantize: bool = True,
) -> None:
    """Export a trained CRNN checkpoint to ONNX and optionally quantize it.

    Loads the PyTorch ``state_dict`` from ``model_path``, traces the model
    with a dummy grayscale input, writes the ONNX graph to ``onnx_path``,
    validates it with the ONNX checker, and (when ``quantize`` is True)
    writes a dynamically-quantized copy next to it.

    Args:
        model_path: Path to the PyTorch ``state_dict`` checkpoint (.pth).
        onnx_path: Destination path for the exported ONNX model.
        quantize: When True, also write ``*_quantized.onnx`` and smoke-test
            it with an ONNX Runtime session.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # `assert` is stripped under `python -O`; raise a real exception instead.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    # 1 input channel (grayscale); +1 output class for the CTC blank token.
    model = CRNN(1, len(LETTERS) + 1)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    print("[OK] Model loaded")

    # Dummy trace input: batch 1, 1 channel, height 32, width 100.
    # Batch and width are declared dynamic below, so the concrete values
    # here only drive the trace.
    dummy_input = torch.randn(1, 1, 32, 100)

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={
            "input": {
                0: "batch_size",
                3: "width",
            },
            # Assumes the CRNN emits (seq_len, batch, classes), the usual
            # layout for CTC decoding -- TODO confirm against crnn.CRNN.
            "logits": {
                0: "sequence_length",
                1: "batch_size",
            },
        },
    )

    print(f"ONNX model exported -> {onnx_path}")

    # Structural validation of the exported graph.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("[OK] ONNX model validation passed")

    if quantize:
        quantized_path = onnx_path.replace(".onnx", "_quantized.onnx")

        # Dynamic quantization: weights stored as uint8, activations
        # quantized on the fly at inference time.
        quantize_dynamic(
            model_input=onnx_path,
            model_output=quantized_path,
            weight_type=QuantType.QUInt8,
        )

        print(f"Quantized model saved -> {quantized_path}")

        # Smoke-test: ensure ONNX Runtime can actually load the file.
        ort.InferenceSession(quantized_path)
        print("[OK] Quantized ONNX Runtime session OK")

    print("Export pipeline completed successfully")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: export the fixed-width CRNN checkpoint and also
    # produce the quantized variant alongside the plain ONNX file.
    checkpoint = "fix_width_crnn.pth"
    target = "crnn.onnx"
    export_onnx(model_path=checkpoint, onnx_path=target, quantize=True)
|
|