"""Export and quantize PyTorch checkpoint to INT8 ONNX."""
import torch
import onnx
import onnxsim
from collections import OrderedDict
import os
import sys
import argparse
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from src.minifasv2.model import MultiFTNet
from src.minifasv2.config import get_kernel
def load_model_from_checkpoint(checkpoint_path, device, input_size=128):
    """Build a MultiFTNet and populate it with weights from *checkpoint_path*.

    The checkpoint may store weights under ``model_state_dict`` or
    ``state_dict``, or be a raw state dict itself. Legacy key names
    (DataParallel ``module.`` prefix, old ``prob``/``drop``/``ft`` layer
    names) are remapped to the current module names before loading.
    """
    raw = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Unwrap the state dict from whichever container the checkpoint uses.
    for container_key in ("model_state_dict", "state_dict"):
        if container_key in raw:
            state_dict = raw[container_key]
            break
    else:
        state_dict = raw

    model = MultiFTNet(
        num_channels=3,
        num_classes=2,
        embedding_size=128,
        conv6_kernel=get_kernel(input_size, input_size),
    ).to(device)

    # Ordered old->new substring substitutions. Order matters: the more
    # specific patterns must fire before the generic ones.
    renames = (
        ("model.prob", "model.logits"),
        (".prob", ".logits"),
        ("model.drop", "model.dropout"),
        (".drop", ".dropout"),
        ("FTGenerator.ft.", "FTGenerator.fourier_transform."),
        ("FTGenerator.ft", "FTGenerator.fourier_transform"),
    )
    remapped = OrderedDict()
    for key, tensor in state_dict.items():
        name = key[7:] if key.startswith("module.") else key
        for old, new in renames:
            name = name.replace(old, new)
        remapped[name] = tensor

    # strict=False tolerates keys present in only one of the two sides.
    model.load_state_dict(remapped, strict=False)
    return model
def export_to_onnx(model, output_path, input_size=128):
    """Trace *model* to ONNX (opset 13), simplify it, and save to *output_path*.

    The exported graph keeps a dynamic batch dimension on both the input
    and the output tensor. Returns *output_path*.
    """
    print("Exporting model to ONNX...")
    print(f"Output path: {output_path}")

    model.eval()
    example = torch.randn(1, 3, input_size, input_size)
    batch_dynamic = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
    torch.onnx.export(
        model,
        example,
        output_path,
        input_names=["input"],
        output_names=["output"],
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        dynamic_axes=batch_dynamic,
    )

    # Round-trip through onnx-simplifier to fold shapes and drop dead nodes.
    print("Simplifying ONNX model...")
    simplified, ok = onnxsim.simplify(onnx.load(output_path))
    assert ok, "Simplified ONNX model could not be validated"
    onnx.save(simplified, output_path)

    print("[OK] ONNX model exported")
    return output_path
def quantize_onnx_with_ort(onnx_path, output_path):
    """Dynamically quantize the ONNX model at *onnx_path* to unsigned INT8 weights.

    Writes the quantized model to *output_path* and returns that path, or
    returns ``None`` when onnxruntime is missing or quantization fails
    (an error message is printed in either case).
    """
    try:
        # Imported lazily so the script still works without onnxruntime.
        from onnxruntime.quantization import QuantType, quantize_dynamic

        print("\nQuantizing ONNX model with ONNX Runtime...")
        print(f"Input: {onnx_path}")
        print(f"Output: {output_path}")
        quantize_dynamic(
            model_input=onnx_path,
            model_output=output_path,
            weight_type=QuantType.QUInt8,
        )
    except ImportError:
        print(
            "[ERROR] onnxruntime not installed. Install with: pip install onnxruntime"
        )
        return None
    except Exception as e:
        # Best-effort: report the failure and let the caller fall back to
        # the unquantized model.
        print(f"[ERROR] Quantization failed: {e}")
        return None
    else:
        print("[OK] Quantized ONNX model created")
        return output_path
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Export model to ONNX and quantize it using ONNX Runtime"
    )
    parser.add_argument("checkpoint_path", type=str, help="Path to .pth checkpoint")
    parser.add_argument(
        "--input_size", type=int, default=128, help="Input image size (default: 128)"
    )
    parser.add_argument(
        "--output_onnx",
        type=str,
        default=None,
        help="Path to save regular .onnx (default: replaces .pth with .onnx)",
    )
    parser.add_argument(
        "--output_quantized",
        type=str,
        default=None,
        help="Path to save quantized .onnx (default: adds _quantized suffix)",
    )
    parser.add_argument(
        "--skip_regular",
        action="store_true",
        help="Skip exporting regular ONNX if it already exists",
    )
    args = parser.parse_args()

    # Validate explicitly rather than with `assert`: asserts are stripped
    # under `python -O`, silently disabling the check.
    if not os.path.isfile(args.checkpoint_path):
        parser.error(f"Checkpoint not found: {args.checkpoint_path}")

    device = "cpu"
    print(f"Using device: {device}")

    print(f"\nLoading model from {args.checkpoint_path}...")
    model = load_model_from_checkpoint(args.checkpoint_path, device, args.input_size)
    print("[OK] Model loaded")

    # Derive default output paths with splitext rather than str.replace:
    # .replace(".pth", ...) matches anywhere in the path and silently reuses
    # the input path when the checkpoint lacks a .pth suffix.
    base, _ = os.path.splitext(args.checkpoint_path)
    if args.output_onnx is None:
        args.output_onnx = base + ".onnx"
    if args.output_quantized is None:
        args.output_quantized = base + "_quantized.onnx"

    # Export the float model unless --skip_regular was given AND it already exists.
    if not args.skip_regular or not os.path.exists(args.output_onnx):
        export_to_onnx(model, args.output_onnx, args.input_size)
        onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024)
        print(f"Regular ONNX size: {onnx_size:.2f} MB")
    else:
        print(f"Using existing ONNX: {args.output_onnx}")

    result = quantize_onnx_with_ort(args.output_onnx, args.output_quantized)
    if result:
        quantized_size = os.path.getsize(args.output_quantized) / (1024 * 1024)
        onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024)
        print(f"\nQuantized ONNX size: {quantized_size:.2f} MB")
        # Report the actual reduction; the previous message printed the
        # remaining fraction ("% of original") while labelling it a reduction.
        print(f"Size reduction: {(1 - quantized_size / onnx_size) * 100:.1f}%")
        print(f"\n[OK] Done! Quantized ONNX saved: {args.output_quantized}")
    else:
        print(
            "\n[WARNING] Quantization failed. Regular ONNX is available at:",
            args.output_onnx,
        )
        print(
            "For regular ONNX export only, use: python scripts/export_onnx.py <checkpoint>"
        )