sol9x-sagar's picture
initial setup
2979822
"""Export and quantize PyTorch checkpoint to INT8 ONNX."""
import torch
import onnx
import onnxsim
from collections import OrderedDict
import os
import sys
import argparse
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from src.minifasv2.model import MultiFTNet
from src.minifasv2.config import get_kernel
# Legacy attribute names that may appear as a dotted segment of a checkpoint
# key, mapped to the names the keys are rewritten to before loading.
_LEGACY_SEGMENT_NAMES = {"prob": "logits", "drop": "dropout"}


def _rename_legacy_key(key):
    """Translate one checkpoint key to the naming scheme used when loading.

    A leading ``module.`` prefix is dropped (presumably left behind by a
    DataParallel-style wrapper -- confirm against the training code).  The
    remaining key is then rewritten segment by segment:

    * ``prob`` -> ``logits`` and ``drop`` -> ``dropout`` for any segment after
      the first,
    * ``ft`` -> ``fourier_transform`` when it directly follows ``FTGenerator``.

    Matching whole segments (instead of raw substrings) keeps the rewrite
    idempotent: a key that already says ``dropout`` is left alone rather than
    being turned into ``dropoutout``.
    """
    if key.startswith("module."):
        key = key[len("module."):]

    parts = key.split(".")
    # The first segment is never renamed; only names reached through a
    # parent attribute are rewritten.
    renamed = parts[:1]
    for previous, current in zip(parts, parts[1:]):
        if previous == "FTGenerator" and current == "ft":
            current = "fourier_transform"
        else:
            current = _LEGACY_SEGMENT_NAMES.get(current, current)
        renamed.append(current)
    return ".".join(renamed)


def load_model_from_checkpoint(checkpoint_path, device, input_size=128):
    """Build a ``MultiFTNet`` and load weights from a PyTorch checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file passed to ``torch.load``.
        device: Device used both as ``map_location`` and as the target of
            ``model.to(...)``.
        input_size: Square input resolution; forwarded to ``get_kernel`` to
            pick the ``conv6_kernel`` argument of the network.

    Returns:
        The model with the checkpoint weights loaded.  Loading is non-strict,
        so keys that do not match are skipped; every skipped key is reported
        with a ``[WARNING]`` line instead of being dropped silently.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects.  Only run this on checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # The weights may be nested under one of two well-known keys, or the
    # checkpoint may itself be the state dict.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    kernel_size = get_kernel(input_size, input_size)
    model = MultiFTNet(
        num_channels=3,
        num_classes=2,
        embedding_size=128,
        conv6_kernel=kernel_size,
    ).to(device)

    new_state_dict = OrderedDict(
        (_rename_legacy_key(key), value) for key, value in state_dict.items()
    )

    load_result = model.load_state_dict(new_state_dict, strict=False)

    # strict=False never raises on a key mismatch, so surface mismatches here;
    # otherwise a wrong checkpoint would yield an untrained export unnoticed.
    reports = (
        ("expected by the model but missing from the checkpoint", load_result.missing_keys),
        ("present in the checkpoint but unused by the model", load_result.unexpected_keys),
    )
    for description, keys in reports:
        if not keys:
            continue
        preview = ", ".join(keys[:5])
        if len(keys) > 5:
            preview += f", ... (+{len(keys) - 5} more)"
        print(f"[WARNING] {len(keys)} key(s) {description}: {preview}")

    return model
def export_to_onnx(model, output_path, input_size=128):
    """Export ``model`` to ONNX at ``output_path`` and simplify it in place.

    Args:
        model: Module to export; it is switched to eval mode first.
        output_path: Destination file.  It is written by the exporter and then
            overwritten with the simplified graph.
        input_size: Height and width of the ``(1, 3, input_size, input_size)``
            random tensor used to trace the model.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If ``onnxsim.simplify`` reports that the simplified
            model failed its validation check.
    """
    print("Exporting model to ONNX...")
    print(f"Output path: {output_path}")
    model.eval()

    # Create the tracing input on the same device as the weights so the export
    # also works for a model that is not on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, input_size, input_size, device=device)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        # Leave the batch dimension symbolic on both the input and the output.
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    onnx_model = onnx.load(output_path)
    print("Simplifying ONNX model...")
    onnx_model, check = onnxsim.simplify(onnx_model)
    # Explicit raise rather than assert: asserts are stripped under
    # `python -O`, which would let an unvalidated model be saved.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    onnx.save(onnx_model, output_path)

    print("[OK] ONNX model exported")
    return output_path
def quantize_onnx_with_ort(onnx_path, output_path):
    """Dynamically quantize an ONNX model with ONNX Runtime.

    Weights are quantized with ``QuantType.QUInt8``.  This is a best-effort
    step: failures are printed and signalled through the return value rather
    than raised.

    Args:
        onnx_path: Path of the ONNX model to quantize.
        output_path: Path the quantized model is written to.

    Returns:
        ``output_path`` on success, or ``None`` when ONNX Runtime's
        quantization tools cannot be imported or quantization fails.
    """
    # Guard only the import here, so that an ImportError raised later inside
    # quantize_dynamic is not misreported as "onnxruntime not installed".
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
    except ImportError:
        print(
            "[ERROR] onnxruntime not installed. Install with: pip install onnxruntime"
        )
        return None

    print("\nQuantizing ONNX model with ONNX Runtime...")
    print(f"Input: {onnx_path}")
    print(f"Output: {output_path}")

    try:
        quantize_dynamic(
            model_input=onnx_path,
            model_output=output_path,
            weight_type=QuantType.QUInt8,
        )
    except Exception as e:
        # Deliberately broad: the caller falls back to the unquantized model.
        print(f"[ERROR] Quantization failed: {e}")
        return None

    print("[OK] Quantized ONNX model created")
    return output_path
def main():
    """Command-line entry point: export a checkpoint to ONNX, then quantize it.

    Steps:
        1. Parse arguments and check that the checkpoint file exists.
        2. Load the model on the CPU.
        3. Export the regular ONNX model, unless ``--skip_regular`` is given
           and the file already exists.
        4. Quantize the ONNX model and print the size comparison.

    Default output paths are derived from the checkpoint path by swapping its
    final extension, so a checkpoint that does not end in ``.pth`` can never
    be chosen as the export target and overwritten.
    """
    parser = argparse.ArgumentParser(
        description="Export model to ONNX and quantize it using ONNX Runtime"
    )
    parser.add_argument("checkpoint_path", type=str, help="Path to .pth checkpoint")
    parser.add_argument(
        "--input_size", type=int, default=128, help="Input image size (default: 128)"
    )
    parser.add_argument(
        "--output_onnx",
        type=str,
        default=None,
        help="Path to save regular .onnx (default: replaces .pth with .onnx)",
    )
    parser.add_argument(
        "--output_quantized",
        type=str,
        default=None,
        help="Path to save quantized .onnx (default: adds _quantized suffix)",
    )
    parser.add_argument(
        "--skip_regular",
        action="store_true",
        help="Skip exporting regular ONNX if it already exists",
    )
    args = parser.parse_args()

    # Report a bad path as a usage error; an assert would vanish under -O.
    if not os.path.isfile(args.checkpoint_path):
        parser.error(f"Checkpoint not found: {args.checkpoint_path}")

    device = "cpu"
    print(f"Using device: {device}")
    print(f"\nLoading model from {args.checkpoint_path}...")
    model = load_model_from_checkpoint(args.checkpoint_path, device, args.input_size)
    print("[OK] Model loaded")

    # Strip only the final extension.  A plain str.replace(".pth", ...) leaves
    # paths such as "model.pt" unchanged, which made the export target the
    # checkpoint itself.
    checkpoint_stem = os.path.splitext(args.checkpoint_path)[0]
    if args.output_onnx is None:
        args.output_onnx = checkpoint_stem + ".onnx"
    if args.output_quantized is None:
        args.output_quantized = checkpoint_stem + "_quantized.onnx"

    if not args.skip_regular or not os.path.exists(args.output_onnx):
        export_to_onnx(model, args.output_onnx, args.input_size)
        onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024)
        print(f"Regular ONNX size: {onnx_size:.2f} MB")
    else:
        print(f"Using existing ONNX: {args.output_onnx}")

    result = quantize_onnx_with_ort(args.output_onnx, args.output_quantized)
    if result:
        quantized_size = os.path.getsize(args.output_quantized) / (1024 * 1024)
        onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024)
        print(f"\nQuantized ONNX size: {quantized_size:.2f} MB")
        print(f"Size reduction: {quantized_size/onnx_size*100:.1f}% of original")
        print(f"\n[OK] Done! Quantized ONNX saved: {args.output_quantized}")
    else:
        print(
            "\n[WARNING] Quantization failed. Regular ONNX is available at:",
            args.output_onnx,
        )
        print(
            "For regular ONNX export only, use: python scripts/export_onnx.py <checkpoint>"
        )


if __name__ == "__main__":
    main()