# rangoli-classifier/scripts/export_onnx.py
# (uploaded by shashidharak99, commit 0b3dd07)
"""
============================================================
Export Model to ONNX for Deployment
============================================================
Usage:
python scripts/export_onnx.py --checkpoint results/checkpoints/resnet50_best.pth
============================================================
"""
import os
import sys
import yaml
import argparse
import torch
import torch.onnx
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.classifier import build_model
def _default_output_path(checkpoint_path):
    """Return *checkpoint_path* with its final extension replaced by ``.onnx``."""
    # splitext touches only the last extension. The previous
    # str.replace(".pth", ".onnx") rewrote every ".pth" in the path and left
    # e.g. "model.pt" unchanged, which made the export overwrite the checkpoint.
    root, _ = os.path.splitext(checkpoint_path)
    return root + ".onnx"


def _verify_onnx_export(output_path, dummy_input, reference_output):
    """Validate the exported ONNX file and, if possible, compare it with PyTorch.

    Runs the ONNX checker and prints the file size. When onnxruntime is
    installed, it also runs the graph on *dummy_input*, prints the output
    shape, and (if *reference_output* is a tensor) prints the largest
    absolute difference from the PyTorch output for the same input.
    """
    import onnx  # lazy import: only needed for verification

    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    size_mb = os.path.getsize(output_path) / (1024 ** 2)
    print(f"Exported: {output_path} ({size_mb:.1f} MB)")

    # Test with ONNX Runtime (optional dependency).
    try:
        import onnxruntime as ort
    except ImportError:
        print("Install onnxruntime to verify: pip install onnxruntime")
        return

    # Request the CPU provider explicitly: some GPU-enabled onnxruntime builds
    # refuse to create a session when no provider list is given.
    session = ort.InferenceSession(output_path, providers=["CPUExecutionProvider"])
    result = session.run(None, {"input": dummy_input.numpy()})
    print(f"ONNX Runtime test: output shape = {result[0].shape}")

    # Only a single-tensor model output can be compared against the single
    # ONNX output named "output".
    if isinstance(reference_output, torch.Tensor):
        onnx_output = torch.from_numpy(result[0]).to(reference_output.dtype)
        max_diff = (onnx_output - reference_output).abs().max().item()
        print(f"ONNX Runtime test: max |onnx - torch| = {max_diff:.2e}")
        if not torch.allclose(onnx_output, reference_output, rtol=1e-3, atol=1e-5):
            print("WARNING: ONNX Runtime output differs from the PyTorch output beyond tolerance")


def export_to_onnx(checkpoint_path, config_path="configs/config.yaml", output_path=None,
                   *, input_size=224, opset_version=14):
    """Export a trained checkpoint to ONNX and verify the exported graph.

    Args:
        checkpoint_path: Path to a ``torch.save``'d dict holding at least
            ``"model_name"`` and ``"state_dict"``.
        config_path: YAML config passed to ``build_model`` together with the
            model name.
        output_path: Destination ``.onnx`` file. Defaults to the checkpoint
            path with its extension replaced by ``.onnx``.
        input_size: Height and width of the square dummy input used for
            tracing; the export uses shape ``(1, 3, input_size, input_size)``.
        opset_version: ONNX opset to target.

    Returns:
        The path the ONNX model was written to.

    Raises:
        ValueError: If the resolved output path is the checkpoint itself.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # NOTE(review): recent torch versions default torch.load to
    # weights_only=True; checkpoints holding arbitrary Python objects may need
    # an explicit weights_only=False (only for trusted files) — confirm.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_name = checkpoint["model_name"]
    model = build_model(model_name, config)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    if output_path is None:
        output_path = _default_output_path(checkpoint_path)
    if os.path.abspath(output_path) == os.path.abspath(checkpoint_path):
        raise ValueError(
            f"Refusing to overwrite the checkpoint with the ONNX export: {output_path}"
        )
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    dummy_input = torch.randn(1, 3, input_size, input_size)
    # Reference output from the eager model, used to check numerical parity
    # of the exported graph.
    with torch.no_grad():
        reference_output = model(dummy_input)

    torch.onnx.export(
        model, dummy_input, output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        # Let the batch dimension vary at inference time.
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )

    _verify_onnx_export(output_path, dummy_input, reference_output)
    return output_path
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the export.
    parser = argparse.ArgumentParser(
        description="Export a trained checkpoint to ONNX and verify the result.",
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="path to a checkpoint file containing 'model_name' and 'state_dict'",
    )
    parser.add_argument(
        "--config", type=str, default="configs/config.yaml",
        help="YAML config passed to build_model (default: %(default)s)",
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="destination .onnx path (default: the checkpoint path with '.pth' replaced by '.onnx')",
    )
    args = parser.parse_args()
    export_to_onnx(args.checkpoint, args.config, args.output)