IIIT-5K / export.py
Tejveer12's picture
Upload 6 files
dca9ee4 verified
import os
import torch
import torch.onnx
import onnx
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType
from crnn import CRNN
# Character set of the recognizer: uppercase A-Z followed by the digits 0-9
# (36 symbols). export_onnx() builds the CRNN with len(LETTERS) + 1 output
# classes; the extra class is presumably the CTC blank -- TODO confirm against
# the training code.
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
def _quantized_output_path(onnx_path: str) -> str:
    """Return the file path used for the quantized sibling of *onnx_path*.

    ``crnn.onnx`` becomes ``crnn_quantized.onnx``. Only the final extension is
    rewritten, so a ``.onnx`` fragment elsewhere in the path (for example in a
    directory name) is left alone, and a path that does not end in ``.onnx``
    still yields a distinct file instead of aliasing the float model.
    """
    stem, ext = os.path.splitext(onnx_path)
    return f"{stem}_quantized{ext or '.onnx'}"


def export_onnx(
    model_path: str,
    onnx_path: str,
    quantize: bool = True,
) -> None:
    """Export a trained CRNN checkpoint to ONNX and optionally quantize it.

    The function loads the state dict at ``model_path`` into a
    ``CRNN(1, len(LETTERS) + 1)`` instance, exports it to ``onnx_path`` with
    opset 12, and validates the result with ``onnx.checker``. When
    ``quantize`` is true it additionally writes a dynamically quantized
    (``QUInt8`` weights) copy next to the exported file and opens it with
    ONNX Runtime as a smoke test. Progress is reported with ``print``.

    Args:
        model_path: Path to the PyTorch state-dict checkpoint.
        onnx_path: Destination path of the exported ONNX model.
        quantize: Whether to also produce ``<name>_quantized.onnx``.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Raise explicitly rather than ``assert`` so the check survives ``python -O``.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    # -------------------------------------------------
    # 1. Initialize model
    # -------------------------------------------------
    model = CRNN(1, len(LETTERS) + 1)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from a
    # trusted source (consider ``weights_only=True`` where the installed torch
    # version supports it).
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    print("βœ… Model loaded")

    # -------------------------------------------------
    # 2. Dummy input (B, C, H, W)
    # -------------------------------------------------
    dummy_input = torch.randn(1, 1, 32, 100)

    # -------------------------------------------------
    # 3. Export to ONNX
    # -------------------------------------------------
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={
            "input": {
                0: "batch_size",
                3: "width",  # variable width for text
            },
            # Axis 0 is declared as the sequence axis and axis 1 as the batch
            # axis; presumably the CRNN emits time-major logits -- verify
            # against crnn.py.
            "logits": {
                1: "batch_size",
                0: "sequence_length",
            },
        },
    )
    print(f"πŸ“¦ ONNX model exported β†’ {onnx_path}")

    # -------------------------------------------------
    # 4. Validate ONNX
    # -------------------------------------------------
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("βœ… ONNX model validation passed")

    # -------------------------------------------------
    # 5. Dynamic Quantization (CPU)
    # -------------------------------------------------
    if quantize:
        quantized_path = _quantized_output_path(onnx_path)
        quantize_dynamic(
            model_input=onnx_path,
            model_output=quantized_path,
            weight_type=QuantType.QUInt8,
        )
        print(f"⚑ Quantized model saved β†’ {quantized_path}")

        # Quick runtime sanity check: constructing the session is enough to
        # prove the quantized graph loads. The provider is named explicitly
        # because some ONNX Runtime builds with extra execution providers
        # refuse to create a session without one.
        ort.InferenceSession(quantized_path, providers=["CPUExecutionProvider"])
        print("βœ… Quantized ONNX Runtime session OK")

    print("πŸŽ‰ Export pipeline completed successfully")
if __name__ == "__main__":
    # Script entry point: convert the fixed-width CRNN checkpoint in the
    # working directory to crnn.onnx and also emit the quantized variant.
    export_onnx("fix_width_crnn.pth", "crnn.onnx", quantize=True)