Spaces:
Runtime error
Runtime error
| """ | |
| ============================================================ | |
| Export Model to ONNX for Deployment | |
| ============================================================ | |
| Usage: | |
| python scripts/export_onnx.py --checkpoint results/checkpoints/resnet50_best.pth | |
| ============================================================ | |
| """ | |
| import os | |
| import sys | |
| import yaml | |
| import argparse | |
| import torch | |
| import torch.onnx | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models.classifier import build_model | |
def export_to_onnx(checkpoint_path, config_path="configs/config.yaml", output_path=None):
    """Export a trained model checkpoint to ONNX and sanity-check the result.

    The checkpoint is loaded on CPU, the model is rebuilt through
    ``build_model``, traced with a fixed 1x3x224x224 random input and written
    with a dynamic batch dimension. The exported file is validated with
    ``onnx.checker`` and, when ``onnxruntime`` is installed, run once on CPU.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that holds
            the keys ``"model_name"`` and ``"state_dict"``.
        config_path: Path to the YAML configuration handed to ``build_model``.
        output_path: Destination of the ONNX file. When ``None``, the
            checkpoint path with its final extension replaced by ``.onnx``
            is used.

    Raises:
        KeyError: If the checkpoint lacks ``"model_name"`` or ``"state_dict"``.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # NOTE: torch.load relies on pickle deserialization; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model = build_model(checkpoint["model_name"], config)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    if output_path is None:
        # Swap only the final extension. A plain str.replace(".pth", ".onnx")
        # would also rewrite ".pth" inside directory names and, for any other
        # extension (e.g. ".pt"), leave the path unchanged so that the export
        # would overwrite the checkpoint itself.
        output_path = os.path.splitext(checkpoint_path)[0] + ".onnx"

    # Example input used for tracing; axis 0 is exported as a dynamic axis.
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )

    # Verify the structural validity of the exported graph.
    import onnx

    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    size_mb = os.path.getsize(output_path) / (1024 ** 2)
    print(f"Exported: {output_path} ({size_mb:.1f} MB)")

    # Optional smoke test with ONNX Runtime.
    try:
        import onnxruntime as ort
    except ImportError:
        print("Install onnxruntime to verify: pip install onnxruntime")
    else:
        # Name the provider explicitly: ONNX Runtime builds that ship extra
        # execution providers reject sessions created without one, and the
        # dummy input lives on the CPU anyway.
        session = ort.InferenceSession(
            output_path, providers=["CPUExecutionProvider"]
        )
        result = session.run(None, {"input": dummy_input.numpy()})
        print(f"ONNX Runtime test: output shape = {result[0].shape}")
if __name__ == "__main__":
    # Command-line entry point: collect the three paths and run the export.
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", type=str, required=True)
    cli.add_argument("--config", type=str, default="configs/config.yaml")
    cli.add_argument("--output", type=str, default=None)
    opts = cli.parse_args()

    export_to_onnx(
        checkpoint_path=opts.checkpoint,
        config_path=opts.config,
        output_path=opts.output,
    )