File size: 1,167 Bytes
02ac88d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from model import SiameseNet

# Load the trained Siamese network and export its embedding branch to ONNX.
device = torch.device("cpu")   # export on CPU for portability
model  = SiameseNet(embedding_dim=128)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; if best.pt holds nothing
# but tensors and plain containers, consider passing weights_only=True.
ckpt   = torch.load("../checkpoints/best.pt", map_location=device)
model.load_state_dict(ckpt["model_state"])
model.eval()   # evaluation mode, so the traced graph reflects inference behavior

# Export the embedding net only (that's all you need at inference)
# Example input used for tracing; presumably (batch, channels, height, width)
# for single-channel 105x105 images - confirm against the training pipeline.
dummy = torch.randn(1, 1, 105, 105)

torch.onnx.export(
    model.embedding_net,
    dummy,
    "../checkpoints/siamese_embedding.onnx",
    input_names  = ["image"],
    output_names = ["embedding"],
    # Axis 0 is declared dynamic so the exported graph accepts any batch size.
    dynamic_axes = {"image": {0: "batch"}, "embedding": {0: "batch"}},
    opset_version = 17,
)
# The arrow in this message was mis-encoded ("β†’"); restored to the intended "→".
print("ONNX model exported → checkpoints/siamese_embedding.onnx")

# ── Verify with onnxruntime ───────────────────────────────────
# Runs the exported graph and checks it numerically against the PyTorch
# embedding net, so "passed" is only printed when the outputs actually agree.
import onnxruntime as ort
import numpy as np

# Name the CPU provider explicitly: onnxruntime builds that ship several
# execution providers may refuse to create a session without `providers`.
sess = ort.InferenceSession(
    "../checkpoints/siamese_embedding.onnx",
    providers=["CPUExecutionProvider"],
)
out  = sess.run(None, {"image": dummy.numpy()})
print(f"ONNX output shape  : {out[0].shape}")     # (1, 128)
# NOTE(review): ~1.0 assumes the embedding net L2-normalizes its output - confirm.
print(f"ONNX output norm   : {np.linalg.norm(out[0]):.4f}")  # ~1.0

# Reference output from the PyTorch module for the same input.
with torch.no_grad():
    expected = model.embedding_net(dummy).numpy()
# Raises AssertionError (with a mismatch report) if the two backends disagree.
np.testing.assert_allclose(out[0], expected, rtol=1e-3, atol=1e-5)

# Exercise the dynamic batch axis declared at export time with a batch > 1.
batch = torch.randn(4, 1, 105, 105)
with torch.no_grad():
    expected_batch = model.embedding_net(batch).numpy()
out_batch = sess.run(None, {"image": batch.numpy()})
np.testing.assert_allclose(out_batch[0], expected_batch, rtol=1e-3, atol=1e-5)

print("ONNX verification passed")