# LETTER / src/export_onnx.py
# Uploaded by Sharath33 via huggingface_hub (commit 02ac88d).
"""Export the SiameseNet embedding network to ONNX and verify the export.

Loads trained weights from ``checkpoints/best.pt``, exports only
``model.embedding_net`` (all that is needed at inference time) to
``checkpoints/siamese_embedding.onnx`` with a dynamic batch axis, then runs
the exported graph in onnxruntime and checks its outputs numerically against
the PyTorch module.
"""
from pathlib import Path

import torch

from model import SiameseNet

# Resolve checkpoint paths relative to this file rather than the current
# working directory: a bare "../checkpoints" only works when the script is
# launched from inside src/.
CKPT_DIR = Path(__file__).resolve().parent.parent / "checkpoints"
CKPT_PATH = CKPT_DIR / "best.pt"
ONNX_PATH = CKPT_DIR / "siamese_embedding.onnx"

device = torch.device("cpu")  # export on CPU for portability

model = SiameseNet(embedding_dim=128)
# torch.load unpickles the file -- only point it at checkpoints you trust.
# NOTE(review): newer torch versions default to weights_only=True, which may
# reject checkpoints storing non-tensor objects next to "model_state" -- confirm.
ckpt = torch.load(CKPT_PATH, map_location=device)
model.load_state_dict(ckpt["model_state"])
# Put the network in inference mode before tracing (affects layers such as
# dropout / batch-norm, if the model has any).
model.eval()

# Export the embedding net only (that's all you need at inference).
# Input layout (N, 1, 105, 105) is taken from the original dummy input;
# presumably single-channel 105x105 images -- confirm against training code.
dummy = torch.randn(1, 1, 105, 105)
torch.onnx.export(
    model.embedding_net,
    dummy,
    str(ONNX_PATH),
    input_names=["image"],
    output_names=["embedding"],
    # Only the batch dimension is dynamic; the image size is fixed at export.
    dynamic_axes={"image": {0: "batch"}, "embedding": {0: "batch"}},
    opset_version=17,
)
print(f"ONNX model exported -> {ONNX_PATH}")

# ── Verify with onnxruntime ───────────────────────────────────
# Imported only after the export so the .onnx file is still written even
# when onnxruntime is not installed.
import numpy as np
import onnxruntime as ort

# Pin the CPU provider explicitly: onnxruntime builds with several execution
# providers (e.g. onnxruntime-gpu) require `providers` to be given.
sess = ort.InferenceSession(str(ONNX_PATH), providers=["CPUExecutionProvider"])

out = sess.run(["embedding"], {"image": dummy.numpy()})[0]
print(f"ONNX output shape : {out.shape}")  # expected (1, 128) for embedding_dim=128
# ~1.0 only if embedding_net L2-normalises its output -- TODO confirm.
print(f"ONNX output norm : {np.linalg.norm(out):.4f}")

# Actually compare the ONNX graph against PyTorch. The second input uses a
# different batch size to exercise the dynamic "batch" axis declared above.
for x in (dummy, torch.randn(4, 1, 105, 105)):
    with torch.no_grad():
        expected = model.embedding_net(x).numpy()
    actual = sess.run(["embedding"], {"image": x.numpy()})[0]
    # Tolerances follow the PyTorch "export to ONNX + ONNX Runtime" tutorial.
    np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-5)
print("ONNX verification passed")