"""Export the SiameseNet embedding branch to ONNX and verify it with onnxruntime.

Loads ``../checkpoints/best.pt``, exports ``model.embedding_net`` to
``../checkpoints/siamese_embedding.onnx`` with a dynamic batch axis, then
checks that onnxruntime reproduces the PyTorch embeddings numerically.
Both paths are relative to the current working directory.
"""
import numpy as np
import onnxruntime as ort
import torch

from model import SiameseNet

CHECKPOINT_PATH = "../checkpoints/best.pt"
ONNX_PATH = "../checkpoints/siamese_embedding.onnx"
EMBEDDING_DIM = 128
# Per-sample input shape (channels, height, width) used to trace the export.
INPUT_SHAPE = (1, 105, 105)
OPSET_VERSION = 17
# Comparison tolerances, as used in the PyTorch ONNX/onnxruntime tutorial.
RTOL, ATOL = 1e-3, 1e-5


def load_model(device):
    """Build a SiameseNet, load weights from CHECKPOINT_PATH, and set eval mode."""
    model = SiameseNet(embedding_dim=EMBEDDING_DIM)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()  # also puts model.embedding_net (a submodule) in eval mode
    return model


def export_embedding_net(model, dummy):
    """Export only ``model.embedding_net`` to ONNX_PATH, traced with ``dummy``.

    The batch dimension of both input and output is declared dynamic, so the
    exported graph should accept any batch size.
    """
    torch.onnx.export(
        model.embedding_net,
        dummy,
        ONNX_PATH,
        input_names=["image"],
        output_names=["embedding"],
        dynamic_axes={"image": {0: "batch"}, "embedding": {0: "batch"}},
        opset_version=OPSET_VERSION,
    )
    print(f"ONNX model exported → {ONNX_PATH}")


def verify_onnx(model, dummy):
    """Check that onnxruntime matches PyTorch on ``dummy`` and on a batch of 2.

    Raises:
        AssertionError: if any ONNX output differs from the PyTorch output
            beyond RTOL/ATOL.
    """
    # Pin the CPU provider: since ORT 1.9, builds with several execution
    # providers (e.g. onnxruntime-gpu) raise if ``providers`` is omitted.
    sess = ort.InferenceSession(ONNX_PATH, providers=["CPUExecutionProvider"])
    # The batch-of-2 input exercises the dynamic batch axis declared at export.
    for x in (dummy, torch.randn(2, *INPUT_SHAPE)):
        onnx_out = sess.run(None, {"image": x.numpy()})[0]
        with torch.no_grad():
            torch_out = model.embedding_net(x).numpy()
        np.testing.assert_allclose(onnx_out, torch_out, rtol=RTOL, atol=ATOL)
        print(f"ONNX output shape : {onnx_out.shape}")  # (batch, EMBEDDING_DIM)
        # Per-row norms; ~1.0 only if embedding_net L2-normalizes — TODO confirm.
        norms = np.linalg.norm(onnx_out, axis=1)
        print(f"ONNX output norm  : {np.array2string(norms, precision=4)}")
    print("ONNX verification passed")


def main():
    """Load the checkpoint, export the embedding net, and verify the export."""
    device = torch.device("cpu")  # export on CPU for portability
    torch.manual_seed(0)  # reproducible dummy inputs
    model = load_model(device)
    # Only the embedding net is exported (that's all you need at inference).
    dummy = torch.randn(1, *INPUT_SHAPE)
    export_embedding_net(model, dummy)
    verify_onnx(model, dummy)


if __name__ == "__main__":
    main()