File size: 1,167 Bytes
"""Export the Siamese embedding network to ONNX and verify it with onnxruntime.

Loads trained weights from ``../checkpoints/best.pt``, exports only
``model.embedding_net`` (the part needed at inference) and then checks that
onnxruntime reproduces the PyTorch embeddings numerically, using a batch size
different from the export dummy so the dynamic ``batch`` axis is exercised.

Paths are plain relative strings, so they resolve against the current working
directory.
"""
import numpy as np
import onnxruntime as ort
import torch

from model import SiameseNet

CHECKPOINT_PATH = "../checkpoints/best.pt"
ONNX_PATH = "../checkpoints/siamese_embedding.onnx"
EMBEDDING_DIM = 128
# Dummy input shape used for tracing; presumably (N, C, H, W) = one
# single-channel 105x105 image -- confirm against the training pipeline.
INPUT_SHAPE = (1, 1, 105, 105)

device = torch.device("cpu")  # export on CPU for portability
model = SiameseNet(embedding_dim=EMBEDDING_DIM)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt["model_state"])
model.eval()

# Export the embedding net only (that's all you need at inference).
embedding_net = model.embedding_net
dummy = torch.randn(*INPUT_SHAPE)
torch.onnx.export(
    embedding_net,
    dummy,
    ONNX_PATH,
    input_names=["image"],
    output_names=["embedding"],
    dynamic_axes={"image": {0: "batch"}, "embedding": {0: "batch"}},
    opset_version=17,
)
print(f"ONNX model exported -> {ONNX_PATH}")

# -- Verify with onnxruntime ------------------------------------------------
# Batch size 2 (export used 1) so the dynamic "batch" axis is really tested.
check_input = torch.randn(2, *INPUT_SHAPE[1:])
with torch.no_grad():  # reference output must not carry autograd history
    expected = embedding_net(check_input).numpy()

# Pin the provider explicitly: GPU-enabled onnxruntime builds raise if
# `providers` is omitted.
sess = ort.InferenceSession(ONNX_PATH, providers=["CPUExecutionProvider"])
actual = sess.run(None, {"image": check_input.numpy()})[0]

print(f"ONNX output shape : {actual.shape}")
# Per-embedding norms; ~1.0 each only if embedding_net L2-normalises its output.
print(f"ONNX output norms : {np.linalg.norm(actual, axis=1)}")
# Fails loudly (AssertionError) if ONNX and PyTorch disagree beyond tolerance.
np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-5)
print("ONNX verification passed")