# TwinLiteNet8 / export_onnx.py
# Uploaded by: WEN0256
# Initial release: TwinLiteNet8 (0.44M params, 7-class orchard semantic seg, edge-deployment ready)
# Commit: f5cc6c0 (verified)
"""Export TwinLiteNet8 to ONNX for cross-platform deployment.
Usage:
python export_onnx.py --ckpt run_8class/twinlite8_best.pt --out twinlite8.onnx
python export_onnx.py --ckpt run_8class/twinlite8_best.pt --out twinlite8_dynamic.onnx --dynamic
"""
import argparse, sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from pathlib import Path
import numpy as np, torch
from model.TwinLite_8class import TwinLiteNet8
def _parse_args(argv=None):
    """Parse command-line options for the ONNX export script.

    Args:
        argv: optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    ap = argparse.ArgumentParser(description="Export TwinLiteNet8 to ONNX.")
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--height", type=int, default=360)
    ap.add_argument("--width", type=int, default=640)
    ap.add_argument("--dynamic", action="store_true",
                    help="Allow dynamic batch + spatial dims (slightly slower at runtime)")
    ap.add_argument("--opset", type=int, default=17)
    ap.add_argument("--num-classes", type=int, default=8,
                    help="Number of classes the checkpoint was trained with (default: 8)")
    ap.add_argument("--min-argmax-agreement", type=float, default=0.999,
                    help="Minimum fraction of pixels whose argmax must match between "
                         "PyTorch and ONNX Runtime (default: 0.999)")
    return ap.parse_args(argv)


def _load_model(ckpt_path, num_classes):
    """Build TwinLiteNet8 in eval mode and load weights from *ckpt_path*.

    Accepts either a training checkpoint dict holding the weights under
    ``"model"`` or a bare state_dict.
    """
    print(f"loading ckpt: {ckpt_path}")
    model = TwinLiteNet8(num_classes=num_classes).eval()
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    is_wrapped = isinstance(ckpt, dict) and "model" in ckpt
    model.load_state_dict(ckpt["model"] if is_wrapped else ckpt)
    if is_wrapped:
        # Metadata is optional; never fail the export just because it is missing.
        print(f"  epoch {ckpt.get('epoch', '?')}  tree IoU {ckpt.get('tree_iou_old', '?')}")
    return model


def _export(model, dummy, out_path, opset, dynamic):
    """Trace *model* with *dummy* and write the ONNX graph to *out_path*.

    With *dynamic*, batch and spatial axes of both ``input`` and ``output``
    are exported as symbolic dimensions.
    """
    dyn = None
    if dynamic:
        dyn = {"input": {0: "batch", 2: "height", 3: "width"},
               "output": {0: "batch", 2: "height", 3: "width"}}
    # Create the destination directory so export does not fail on a fresh path.
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    print(f"exporting to ONNX (opset {opset}) ...")
    torch.onnx.export(
        model, dummy, out_path,
        input_names=["input"], output_names=["output"],
        dynamic_axes=dyn,
        opset_version=opset,
        do_constant_folding=True,
    )
    sz = os.path.getsize(out_path) / 1e6
    print(f"  saved: {out_path}  ({sz:.2f} MB)")


def _check_parity(model, dummy, onnx_path, min_agreement):
    """Compare PyTorch and ONNX Runtime outputs on *dummy*.

    Skips (with a message) when onnxruntime is not installed.

    Raises:
        SystemExit: if the per-pixel argmax agreement is below *min_agreement*.
            Unlike ``assert``, this is not stripped under ``python -O``.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        print("  (skip parity check — onnxruntime not installed; pip install onnxruntime)")
        return
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    with torch.no_grad():
        torch_out = model(dummy).numpy()
    onnx_out = sess.run(None, {"input": dummy.numpy()})[0]
    diff = np.abs(torch_out - onnx_out)
    argmax_match = float((torch_out.argmax(1) == onnx_out.argmax(1)).mean())
    print(f"  parity: max_abs_diff={diff.max():.6f}  mean={diff.mean():.6f}")
    print(f"  argmax agreement: {100*argmax_match:.4f}%  (must be ~100% for safe deploy)")
    if argmax_match < min_agreement:
        raise SystemExit(
            f"argmax agreement {100*argmax_match:.4f}% is below the required "
            f"{100*min_agreement:.4f}% — investigate"
        )
    print("  PARITY OK")


def main(argv=None):
    """Export TwinLiteNet8 to ONNX and, if possible, validate numerical parity.

    Args:
        argv: optional CLI argument list (defaults to ``sys.argv[1:]``).
    """
    args = _parse_args(argv)
    model = _load_model(args.ckpt, args.num_classes)
    dummy = torch.randn(1, 3, args.height, args.width)
    _export(model, dummy, args.out, args.opset, args.dynamic)
    _check_parity(model, dummy, args.out, args.min_argmax_agreement)


if __name__ == "__main__":
    main()