"""Export TwinLiteNet8 to ONNX for cross-platform deployment. Usage: python export_onnx.py --ckpt run_8class/twinlite8_best.pt --out twinlite8.onnx python export_onnx.py --ckpt run_8class/twinlite8_best.pt --out twinlite8_dynamic.onnx --dynamic """ import argparse, sys, os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from pathlib import Path import numpy as np, torch from model.TwinLite_8class import TwinLiteNet8 def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt", required=True) ap.add_argument("--out", required=True) ap.add_argument("--height", type=int, default=360) ap.add_argument("--width", type=int, default=640) ap.add_argument("--dynamic", action="store_true", help="Allow dynamic batch + spatial dims (slightly slower at runtime)") ap.add_argument("--opset", type=int, default=17) args = ap.parse_args() print(f"loading ckpt: {args.ckpt}") model = TwinLiteNet8(num_classes=8).eval() ckpt = torch.load(args.ckpt, map_location="cpu", weights_only=False) model.load_state_dict(ckpt["model"]) print(f" epoch {ckpt['epoch']} tree IoU {ckpt.get('tree_iou_old','?')}") dummy = torch.randn(1, 3, args.height, args.width) if args.dynamic: dyn = {"input": {0: "batch", 2: "height", 3: "width"}, "output": {0: "batch", 2: "height", 3: "width"}} else: dyn = None print(f"exporting to ONNX (opset {args.opset}) ...") torch.onnx.export( model, dummy, args.out, input_names=["input"], output_names=["output"], dynamic_axes=dyn, opset_version=args.opset, do_constant_folding=True, ) sz = os.path.getsize(args.out) / 1e6 print(f" saved: {args.out} ({sz:.2f} MB)") # Validate ONNX numerical parity vs PyTorch try: import onnxruntime as ort sess = ort.InferenceSession(args.out, providers=["CPUExecutionProvider"]) with torch.no_grad(): torch_out = model(dummy).numpy() onnx_out = sess.run(None, {"input": dummy.numpy()})[0] diff = np.abs(torch_out - onnx_out) argmax_match = (torch_out.argmax(1) == onnx_out.argmax(1)).mean() print(f" parity: max_abs_diff={diff.max():.6f} mean={diff.mean():.6f}") print(f" argmax agreement: {100*argmax_match:.4f}% (must be ~100% for safe deploy)") assert argmax_match > 0.999, "argmax disagreement > 0.1% — investigate" print(" PARITY OK") except ImportError: print(" (skip parity check — onnxruntime not installed; pip install onnxruntime)") if __name__ == "__main__": main()