"""Export and quantize PyTorch checkpoint to INT8 ONNX."""

import torch
import onnx
import onnxsim
from collections import OrderedDict
import os
import sys
import argparse

# Make the repo root importable when this script is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from src.minifasv2.model import MultiFTNet
from src.minifasv2.config import get_kernel


def load_model_from_checkpoint(checkpoint_path, device, input_size=128):
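    """Build a MultiFTNet and load checkpoint weights, remapping legacy keys."""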
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    kernel_size = get_kernel(input_size, input_size)
    model = MultiFTNet(
        num_channels=3,
        num_classes=2,
        embedding_size=128,
        conv6_kernel=kernel_size,
    ).to(device)

    # Remap legacy parameter names so older checkpoints load into the current model.
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        # Strip the "module." prefix added by DataParallel/DistributedDataParallel.
        if new_key.startswith("module."):
            new_key = new_key[7:]
        # Legacy renames: prob -> logits, drop -> dropout, ft -> fourier_transform.
        new_key = new_key.replace("model.prob", "model.logits")
        new_key = new_key.replace(".prob", ".logits")
        new_key = new_key.replace("model.drop", "model.dropout")
        new_key = new_key.replace(".drop", ".dropout")
        new_key = new_key.replace("FTGenerator.ft.", "FTGenerator.fourier_transform.")
        new_key = new_key.replace("FTGenerator.ft", "FTGenerator.fourier_transform")
        new_state_dict[new_key] = value

    # strict=False tolerates keys that don't match the current architecture.
    model.load_state_dict(new_state_dict, strict=False)
    return model


def export_to_onnx(model, output_path, input_size=128):
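    """Trace the model with a dummy input, export to ONNX, and simplify the graph."""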
    print("Exporting model to ONNX...")
    print(f"Output path: {output_path}")

    model.eval()
    dummy_input = torch.randn(1, 3, input_size, input_size)

    # Export with a dynamic batch dimension so the graph accepts any batch size.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # Fold constants and strip redundant ops with onnx-simplifier, then save in place.
    onnx_model = onnx.load(output_path)
    print("Simplifying ONNX model...")
    onnx_model, check = onnxsim.simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(onnx_model, output_path)

    print("[OK] ONNX model exported")
    return output_path


def quantize_onnx_with_ort(onnx_path, output_path):
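    """Apply ONNX Runtime dynamic quantization (8-bit weights) to a saved model."""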
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType

        print("\nQuantizing ONNX model with ONNX Runtime...")
        print(f"Input: {onnx_path}")
        print(f"Output: {output_path}")

        # Dynamic quantization: weights are stored as 8-bit integers and
        # activations are quantized on the fly at inference time.
        quantize_dynamic(
            model_input=onnx_path,
            model_output=output_path,
            weight_type=QuantType.QUInt8,
        )

        print("[OK] Quantized ONNX model created")
        return output_path
    except ImportError:
        print(
            "[ERROR] onnxruntime not installed. Install with: pip install onnxruntime"
        )
        return None
    except Exception as e:
        print(f"[ERROR] Quantization failed: {e}")
        return None
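

# Optional sanity check, a minimal sketch rather than part of the original flow:
# run one random input through both the FP32 and quantized models via onnxruntime
# and report the largest output difference. The function name, tolerance, and
# call site are illustrative assumptions; invoke it manually after quantization,
# e.g. compare_onnx_outputs(args.output_onnx, args.output_quantized).
def compare_onnx_outputs(fp32_path, quantized_path, input_size=128, atol=0.1):
    import numpy as np
    import onnxruntime as ort

    dummy = np.random.randn(1, 3, input_size, input_size).astype(np.float32)
    fp32_session = ort.InferenceSession(
        fp32_path, providers=["CPUExecutionProvider"]
    )
    quant_session = ort.InferenceSession(
        quantized_path, providers=["CPUExecutionProvider"]
    )
    fp32_out = fp32_session.run(None, {"input": dummy})[0]
    quant_out = quant_session.run(None, {"input": dummy})[0]
    max_diff = float(np.abs(fp32_out - quant_out).max())
    print(f"Max |FP32 - INT8| output difference: {max_diff:.4f}")
    return max_diff <= atol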


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Export model to ONNX and quantize it using ONNX Runtime"
    )
    parser.add_argument("checkpoint_path", type=str, help="Path to .pth checkpoint")
    parser.add_argument(
        "--input_size", type=int, default=128, help="Input image size (default: 128)"
    )
    parser.add_argument(
        "--output_onnx",
        type=str,
        default=None,
        help="Path to save regular .onnx (default: replaces .pth with .onnx)",
    )
    parser.add_argument(
        "--output_quantized",
        type=str,
        default=None,
        help="Path to save quantized .onnx (default: adds _quantized suffix)",
    )
    parser.add_argument(
        "--skip_regular",
        action="store_true",
        help="Skip exporting regular ONNX if it already exists",
    )

    args = parser.parse_args()

    assert os.path.isfile(
        args.checkpoint_path
    ), f"Checkpoint not found: {args.checkpoint_path}"

    device = "cpu"
    print(f"Using device: {device}")

    print(f"\nLoading model from {args.checkpoint_path}...")
    model = load_model_from_checkpoint(args.checkpoint_path, device, args.input_size)
    print("[OK] Model loaded")

    if args.output_onnx is None:
        args.output_onnx = os.path.splitext(args.checkpoint_path)[0] + ".onnx"

    # Re-export unless --skip_regular was given and the ONNX file already exists.
    if not args.skip_regular or not os.path.exists(args.output_onnx):
        export_to_onnx(model, args.output_onnx, args.input_size)
        onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024)
        print(f"Regular ONNX size: {onnx_size:.2f} MB")
    else:
        print(f"Using existing ONNX: {args.output_onnx}")

    if args.output_quantized is None:
        args.output_quantized = (
            os.path.splitext(args.checkpoint_path)[0] + "_quantized.onnx"
        )

    result = quantize_onnx_with_ort(args.output_onnx, args.output_quantized)

    if result:
        quantized_size = os.path.getsize(args.output_quantized) / (1024 * 1024)
        onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024)
        print(f"\nQuantized ONNX size: {quantized_size:.2f} MB")
        print(f"Size reduction: {quantized_size/onnx_size*100:.1f}% of original")
        print(f"\n[OK] Done! Quantized ONNX saved: {args.output_quantized}")
    else:
        print(
            "\n[WARNING] Quantization failed. Regular ONNX is available at:",
            args.output_onnx,
        )
        print(
            "For regular ONNX export only, use: python scripts/export_onnx.py <checkpoint>"
        )