File size: 3,654 Bytes
fba30db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""P3: Export EfficientNetAutoAttB4 to ONNX for 2-3× CPU inference speedup.

Exports the model to backend/models/efficientnet_autoattb4_dfdc.onnx.
After export, set EFFICIENTNET_ONNX_PATH in .env to enable ONNX inference.

Requirements (install first):
    pip install onnx onnxruntime

Usage:
    .venv/Scripts/python.exe scripts/export_onnx.py [--out PATH] [--opset N] [--validate]

The --validate flag runs a quick numerical comparison between PyTorch and ONNX
outputs on a random face-shaped input to verify the export is correct, then
prints a short latency benchmark of both backends.
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path

import numpy as np
import torch

# Directory two levels above this script; it is put first on sys.path so the
# function-local `services...` imports below can be resolved from it.
_BACKEND_DIR = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(_BACKEND_DIR))

# Default destination of the exported model (overridable with --out).
ONNX_OUT = _BACKEND_DIR / "models" / "efficientnet_autoattb4_dfdc.onnx"


def export(out_path: Path, opset: int = 17) -> None:
    """Export the EfficientNetAutoAttB4 network to an ONNX file.

    The detector's ``net`` is switched to eval mode, moved to the CPU and
    exported using a ``1x3x224x224`` all-zero example input. Axis 0 of both
    the ``face`` input and the ``logit`` output is declared dynamic under the
    name ``batch``.

    Args:
        out_path: Destination ``.onnx`` file; missing parent directories are
            created.
        opset: ONNX opset version passed to ``torch.onnx.export``.
    """
    print("Loading EfficientNetAutoAttB4…")
    # Imported here so the detector module is only loaded when an export runs.
    from services.efficientnet_service import EfficientNetDetector

    model = EfficientNetDetector().net.eval().cpu()
    example_input = torch.zeros(1, 3, 224, 224)

    print(f"Exporting to ONNX (opset {opset})…")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    export_options = {
        "opset_version": opset,
        "input_names": ["face"],
        "output_names": ["logit"],
        # One independent axis mapping per tensor name.
        "dynamic_axes": {tensor: {0: "batch"} for tensor in ("face", "logit")},
        "do_constant_folding": True,
    }
    torch.onnx.export(model, example_input, str(out_path), **export_options)

    megabytes = out_path.stat().st_size / 1024 / 1024
    print(f"Saved: {out_path}  ({megabytes:.1f} MB)")


def validate(out_path: Path) -> None:
    """Check the exported ONNX model against PyTorch and print a benchmark.

    Runs the PyTorch network and an ONNX Runtime CPU session on the same
    random ``1x3x224x224`` input, prints the maximum absolute difference of
    their outputs (``[PASS]`` below ``1e-4``, ``[WARN]`` otherwise), then
    times both backends and prints the average latency and speedup. If
    ``onnxruntime`` is not importable, a message is printed and the function
    returns without doing anything else.

    Args:
        out_path: Path of the ``.onnx`` file to load.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        print("onnxruntime not installed — skipping validation. pip install onnxruntime")
        return

    import time

    print("Validating ONNX output vs PyTorch…")
    from services.efficientnet_service import EfficientNetDetector
    det = EfficientNetDetector()
    net = det.net.eval().cpu()

    dummy = torch.randn(1, 3, 224, 224)
    # Build the ONNX feed once: the tensor -> ndarray conversion is
    # loop-invariant and must not be counted in the timed ONNX loop.
    feed = {"face": dummy.numpy()}

    with torch.inference_mode():
        pt_out = net(dummy).numpy()

    sess = ort.InferenceSession(str(out_path), providers=["CPUExecutionProvider"])
    ort_out = sess.run(None, feed)[0]

    max_diff = float(np.abs(pt_out - ort_out).max())
    print(f"  Max absolute diff PyTorch vs ONNX: {max_diff:.6f}")
    if max_diff < 1e-4:
        print("  [PASS] Outputs match within tolerance")
    else:
        print("  [WARN] Outputs differ more than 1e-4 — inspect export")

    # Benchmark.
    N = 20
    warmup = 3

    # Untimed warm-up so one-off initialisation does not skew the averages.
    for _ in range(warmup):
        sess.run(None, feed)
    t0 = time.perf_counter()
    for _ in range(N):
        sess.run(None, feed)
    ort_ms = (time.perf_counter() - t0) / N * 1000

    with torch.inference_mode():
        for _ in range(warmup):
            net(dummy)
        t0 = time.perf_counter()
        for _ in range(N):
            net(dummy)
        pt_ms = (time.perf_counter() - t0) / N * 1000

    print(f"  PyTorch: {pt_ms:.1f} ms/img  |  ONNX: {ort_ms:.1f} ms/img  |  speedup: {pt_ms/ort_ms:.2f}×")
    print(f"\nTo enable ONNX inference, add to .env:\n  EFFICIENTNET_ONNX_PATH={out_path}")


def main() -> int:
    """Command-line entry point: export the model, then optionally validate it.

    Recognised options are ``--out`` (destination file, defaults to
    ``ONNX_OUT``), ``--opset`` (ONNX opset version, defaults to 17) and the
    ``--validate`` switch.

    Returns:
        ``0`` once the export (and validation, if requested) has returned.
    """
    cli = argparse.ArgumentParser(description="Export EfficientNetAutoAttB4 to ONNX")
    cli.add_argument("--out", type=Path, default=ONNX_OUT, help="Output .onnx file path")
    cli.add_argument("--opset", type=int, default=17, help="ONNX opset version (default 17)")
    cli.add_argument("--validate", action="store_true", help="Compare ONNX vs PyTorch outputs and benchmark")
    options = cli.parse_args()

    target = options.out
    export(target, opset=options.opset)

    # Validation is opt-in; without the flag the export alone is the job.
    if not options.validate:
        return 0
    validate(target)
    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())