# deepshield/scripts/export_onnx.py
# Synced from GitHub via hub-sync (commit fba30db, verified).
"""P3: Export EfficientNetAutoAttB4 to ONNX for 2-3× CPU inference speedup.
Exports the model to backend/models/efficientnet_autoattb4_dfdc.onnx.
After export, set EFFICIENTNET_ONNX_PATH in .env to enable ONNX inference.
Requirements (install first):
pip install onnx onnxruntime
Usage:
.venv/Scripts/python.exe scripts/export_onnx.py [--validate]
The --validate flag runs a quick numerical comparison between PyTorch and ONNX
outputs on a random face-shaped input to verify the export is correct.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
ONNX_OUT = Path(__file__).resolve().parent.parent / "models" / "efficientnet_autoattb4_dfdc.onnx"
def export(out_path: Path, opset: int = 17) -> None:
    """Export the EfficientNetAutoAttB4 detector network to an ONNX file.

    Args:
        out_path: Destination ``.onnx`` path; parent directories are created.
        opset: ONNX opset version to target (default 17).
    """
    print("Loading EfficientNetAutoAttB4…")
    from services.efficientnet_service import EfficientNetDetector

    model = EfficientNetDetector().net.eval().cpu()
    # Tracing input: one 3×224×224 face crop; batch axis is declared dynamic.
    sample = torch.zeros(1, 3, 224, 224)
    print(f"Exporting to ONNX (opset {opset})…")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.onnx.export(
        model,
        sample,
        str(out_path),
        opset_version=opset,
        input_names=["face"],
        output_names=["logit"],
        dynamic_axes={"face": {0: "batch"}, "logit": {0: "batch"}},
        do_constant_folding=True,
    )
    mb = out_path.stat().st_size / 1024 / 1024
    print(f"Saved: {out_path} ({mb:.1f} MB)")
def validate(out_path: Path) -> None:
    """Numerically compare the exported ONNX model against PyTorch, then benchmark.

    Runs both backends on the same random 1×3×224×224 input, prints the max
    absolute output difference with a PASS/WARN verdict at 1e-4 tolerance,
    and reports per-image latency for each backend. Skips gracefully when
    onnxruntime is not installed.

    Args:
        out_path: Path to the previously exported ``.onnx`` file.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        print("onnxruntime not installed — skipping validation. pip install onnxruntime")
        return
    print("Validating ONNX output vs PyTorch…")
    from services.efficientnet_service import EfficientNetDetector
    det = EfficientNetDetector()
    net = det.net.eval().cpu()
    dummy = torch.randn(1, 3, 224, 224)
    # Convert to ndarray ONCE. The original called dummy.numpy() inside the
    # ONNX benchmark loop, charging the tensor→ndarray conversion to ONNX on
    # every iteration while PyTorch paid none — a biased comparison.
    dummy_np = dummy.numpy()
    with torch.inference_mode():
        pt_out = net(dummy).numpy()
    sess = ort.InferenceSession(str(out_path), providers=["CPUExecutionProvider"])
    ort_out = sess.run(None, {"face": dummy_np})[0]
    max_diff = float(np.abs(pt_out - ort_out).max())
    print(f" Max absolute diff PyTorch vs ONNX: {max_diff:.6f}")
    if max_diff < 1e-4:
        print(" [PASS] Outputs match within tolerance")
    else:
        print(" [WARN] Outputs differ more than 1e-4 — inspect export")
    # Benchmark: average wall-clock latency over N single-image runs.
    import time
    N = 20
    t0 = time.perf_counter()
    for _ in range(N):
        sess.run(None, {"face": dummy_np})
    ort_ms = (time.perf_counter() - t0) / N * 1000
    t0 = time.perf_counter()
    with torch.inference_mode():
        for _ in range(N):
            net(dummy)
    pt_ms = (time.perf_counter() - t0) / N * 1000
    print(f" PyTorch: {pt_ms:.1f} ms/img | ONNX: {ort_ms:.1f} ms/img | speedup: {pt_ms/ort_ms:.2f}×")
    print(f"\nTo enable ONNX inference, add to .env:\n EFFICIENTNET_ONNX_PATH={out_path}")
def main() -> int:
    """CLI entry point: export the model to ONNX, optionally validating after.

    Returns:
        Process exit code (always 0 on success; argparse exits on bad args).
    """
    ap = argparse.ArgumentParser(description="Export EfficientNetAutoAttB4 to ONNX")
    ap.add_argument("--out", type=Path, default=ONNX_OUT, help="Output .onnx file path")
    ap.add_argument("--opset", type=int, default=17, help="ONNX opset version (default 17)")
    ap.add_argument("--validate", action="store_true", help="Compare ONNX vs PyTorch outputs and benchmark")
    ns = ap.parse_args()
    export(ns.out, opset=ns.opset)
    if ns.validate:
        validate(ns.out)
    return 0
if __name__ == "__main__":
raise SystemExit(main())