"""Export the trained Siamese embedding network to ONNX and verify it.

Loads the best checkpoint, exports only the embedding sub-network (all
that is needed at inference time), then round-trips a dummy input through
onnxruntime to sanity-check the exported graph.
"""
import numpy as np
import onnxruntime as ort
import torch

from model import SiameseNet

# Single source of truth for paths (the original printed a path that did
# not match the one it actually wrote to).
CHECKPOINT_PATH = "../checkpoints/best.pt"
ONNX_PATH = "../checkpoints/siamese_embedding.onnx"


def export_onnx() -> None:
    """Export the embedding net to ONNX_PATH (CPU, opset 17, dynamic batch)."""
    device = torch.device("cpu")  # export on CPU for portability
    model = SiameseNet(embedding_dim=128)
    # NOTE(review): torch.load unpickles arbitrary objects — acceptable for
    # a trusted local checkpoint; consider weights_only=True otherwise.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # Export the embedding net only (that's all you need at inference).
    # Dummy input is a single 1-channel 105x105 image; the batch axis is
    # declared dynamic so any batch size works at inference time.
    dummy = torch.randn(1, 1, 105, 105)
    torch.onnx.export(
        model.embedding_net,
        dummy,
        ONNX_PATH,
        input_names=["image"],
        output_names=["embedding"],
        dynamic_axes={"image": {0: "batch"}, "embedding": {0: "batch"}},
        opset_version=17,
    )
    print(f"ONNX model exported → {ONNX_PATH}")


def verify_onnx() -> None:
    """Run a dummy input through onnxruntime and report output shape/norm."""
    sess = ort.InferenceSession(ONNX_PATH)
    dummy = torch.randn(1, 1, 105, 105)
    out = sess.run(None, {"image": dummy.numpy()})
    print(f"ONNX output shape : {out[0].shape}")  # (1, 128)
    # presumably the net L2-normalizes its embeddings, so norm ≈ 1.0
    # — TODO confirm against the SiameseNet definition
    print(f"ONNX output norm : {np.linalg.norm(out[0]):.4f}")
    print("ONNX verification passed")


if __name__ == "__main__":
    export_onnx()
    verify_onnx()