| """Export TwinLiteNet8 to ONNX for cross-platform deployment. |
| |
| Usage: |
| python export_onnx.py --ckpt run_8class/twinlite8_best.pt --out twinlite8.onnx |
| python export_onnx.py --ckpt run_8class/twinlite8_best.pt --out twinlite8_dynamic.onnx --dynamic |
| """ |
| import argparse, sys, os |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from pathlib import Path |
| import numpy as np, torch |
| from model.TwinLite_8class import TwinLiteNet8 |
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--height", type=int, default=360)
    ap.add_argument("--width", type=int, default=640)
    ap.add_argument("--dynamic", action="store_true",
                    help="Allow dynamic batch + spatial dims (slightly slower at runtime)")
    ap.add_argument("--opset", type=int, default=17)
    return ap.parse_args()


def _check_parity(model: torch.nn.Module, dummy: torch.Tensor, onnx_path: str) -> None:
    """Compare PyTorch and ONNX Runtime outputs on ``dummy``.

    Prints the max/mean absolute difference and the fraction of positions
    whose argmax along axis 1 agrees (presumably the class axis — confirm
    against the model's output layout). The check is skipped with a notice
    when ``onnxruntime`` is not installed.

    Raises:
        AssertionError: if argmax agreement is not above 99.9%.
    """
    # Only the import is guarded: an ImportError raised later (e.g. from inside
    # the model or the runtime) must not be misreported as "not installed".
    try:
        import onnxruntime as ort
    except ImportError:
        print(" (skip parity check — onnxruntime not installed; pip install onnxruntime)")
        return

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    with torch.no_grad():
        torch_out = model(dummy).numpy()
    onnx_out = sess.run(None, {"input": dummy.numpy()})[0]

    diff = np.abs(torch_out - onnx_out)
    argmax_match = (torch_out.argmax(1) == onnx_out.argmax(1)).mean()
    print(f" parity: max_abs_diff={diff.max():.6f} mean={diff.mean():.6f}")
    print(f" argmax agreement: {100*argmax_match:.4f}% (must be ~100% for safe deploy)")

    # Raised explicitly instead of using an `assert` statement so the gate
    # still fires when Python runs with -O (which strips assert statements).
    if not (argmax_match > 0.999):
        raise AssertionError("argmax disagreement > 0.1% — investigate")
    print(" PARITY OK")


def main() -> None:
    """Export a TwinLiteNet8 checkpoint to ONNX and verify the result.

    Steps: parse CLI arguments, load the checkpoint's ``"model"`` state dict
    into an 8-class TwinLiteNet8 in eval mode, export it with a random
    ``(1, 3, height, width)`` input, report the file size, then run a
    PyTorch-vs-ONNX Runtime parity check if ``onnxruntime`` is available.

    Raises:
        AssertionError: if the parity check finds argmax agreement <= 99.9%.
    """
    args = _parse_args()

    print(f"loading ckpt: {args.ckpt}")
    model = TwinLiteNet8(num_classes=8).eval()
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source; consider weights_only=True
    # if the checkpoint holds nothing but tensors and plain containers.
    ckpt = torch.load(args.ckpt, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    # Metadata is informational only, so a missing key must not abort the export.
    print(f" epoch {ckpt.get('epoch', '?')} tree IoU {ckpt.get('tree_iou_old','?')}")

    dummy = torch.randn(1, 3, args.height, args.width)

    # With --dynamic, batch and spatial axes are exported as symbolic dims;
    # otherwise the graph is fixed to the dummy input's shape.
    if args.dynamic:
        dyn = {"input": {0: "batch", 2: "height", 3: "width"},
               "output": {0: "batch", 2: "height", 3: "width"}}
    else:
        dyn = None

    print(f"exporting to ONNX (opset {args.opset}) ...")
    torch.onnx.export(
        model, dummy, args.out,
        input_names=["input"], output_names=["output"],
        dynamic_axes=dyn,
        opset_version=args.opset,
        do_constant_folding=True,
    )

    size_mb = os.path.getsize(args.out) / 1e6
    print(f" saved: {args.out} ({size_mb:.2f} MB)")

    _check_parity(model, dummy, args.out)
|
|
|
|
# Script entry point: run the export only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|