"""
============================================================
Export Model to ONNX for Deployment
============================================================
Usage:
    python scripts/export_onnx.py --checkpoint results/checkpoints/resnet50_best.pth
    python scripts/export_onnx.py --checkpoint ckpt.pth --input-size 256 --opset 17
============================================================
"""
import os
import sys
import yaml
import argparse
import torch
import torch.onnx

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.classifier import build_model


def _default_output_path(checkpoint_path):
    """Return ``checkpoint_path`` with its final extension replaced by ``.onnx``.

    Only the last extension is swapped, so directory names that happen to
    contain ".pth" are left untouched.
    """
    root, _ext = os.path.splitext(checkpoint_path)
    return root + ".onnx"


def _check_onnx_model(output_path):
    """Validate the exported graph with the ONNX checker and report its file size."""
    import onnx

    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    size_mb = os.path.getsize(output_path) / (1024 ** 2)
    print(f"Exported: {output_path} ({size_mb:.1f} MB)")


def _run_onnxruntime_check(output_path, model, dummy_input):
    """Run the exported model with ONNX Runtime (if installed) and compare to PyTorch."""
    try:
        import onnxruntime as ort
    except ImportError:
        print("Install onnxruntime to verify: pip install onnxruntime")
        return

    # Explicit provider list: since ORT 1.9, builds that ship several execution
    # providers (e.g. GPU packages) raise if `providers` is omitted. CPU is
    # always available, which is all a smoke test needs.
    session = ort.InferenceSession(output_path, providers=["CPUExecutionProvider"])
    result = session.run(None, {"input": dummy_input.numpy()})
    print(f"ONNX Runtime test: output shape = {result[0].shape}")

    # Numeric parity with the eager model; skipped when the model returns a
    # non-tensor (e.g. a tuple) or the shapes disagree.
    with torch.no_grad():
        reference = model(dummy_input)
    ort_output = torch.from_numpy(result[0])
    if isinstance(reference, torch.Tensor) and reference.shape == ort_output.shape:
        max_diff = (reference - ort_output).abs().max().item()
        print(f"Max abs diff vs PyTorch: {max_diff:.2e}")


def export_to_onnx(checkpoint_path, config_path="configs/config.yaml", output_path=None,
                   input_size=224, opset_version=14):
    """Export a trained classifier checkpoint to ONNX and sanity-check the result.

    Args:
        checkpoint_path: File produced by ``torch.save`` holding a dict with at
            least the keys ``"model_name"`` and ``"state_dict"``.
        config_path: YAML config passed to ``build_model`` together with the
            model name.
        output_path: Destination ``.onnx`` file. Defaults to the checkpoint path
            with its extension replaced by ``.onnx``. Missing parent
            directories are created.
        input_size: Height and width of the square dummy input used for tracing.
            Only the batch axis is exported as dynamic, so the deployed model
            expects this spatial size.
        opset_version: ONNX opset to target.

    Returns:
        The path the ONNX model was written to.

    Raises:
        ValueError: If the output path resolves to the checkpoint itself, which
            would otherwise overwrite the checkpoint.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints (consider weights_only=True if the torch version supports it).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model = build_model(checkpoint["model_name"], config)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    if output_path is None:
        output_path = _default_output_path(checkpoint_path)
    if os.path.abspath(output_path) == os.path.abspath(checkpoint_path):
        raise ValueError(
            f"Refusing to overwrite checkpoint {checkpoint_path!r} with ONNX output"
        )
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Tracing input: batch of 1, 3 channels, input_size x input_size.
    dummy_input = torch.randn(1, 3, input_size, input_size)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )

    _check_onnx_model(output_path)
    _run_onnxruntime_check(output_path, model, dummy_input)
    return output_path


def main():
    """Parse command-line arguments and run the export."""
    parser = argparse.ArgumentParser(description="Export a trained checkpoint to ONNX.")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--config", type=str, default="configs/config.yaml")
    parser.add_argument("--output", type=str, default=None)
    parser.add_argument("--input-size", type=int, default=224,
                        help="Square spatial size of the tracing input.")
    parser.add_argument("--opset", type=int, default=14,
                        help="ONNX opset version to target.")
    args = parser.parse_args()
    export_to_onnx(
        args.checkpoint,
        args.config,
        args.output,
        input_size=args.input_size,
        opset_version=args.opset,
    )


if __name__ == "__main__":
    main()