"""Export and quantize PyTorch checkpoint to INT8 ONNX.""" import torch import onnx import onnxsim from collections import OrderedDict import os import sys import argparse sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from src.minifasv2.model import MultiFTNet from src.minifasv2.config import get_kernel def load_model_from_checkpoint(checkpoint_path, device, input_size=128): checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) if "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] elif "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] else: state_dict = checkpoint kernel_size = get_kernel(input_size, input_size) model = MultiFTNet( num_channels=3, num_classes=2, embedding_size=128, conv6_kernel=kernel_size, ).to(device) new_state_dict = OrderedDict() for key, value in state_dict.items(): new_key = key if new_key.startswith("module."): new_key = new_key[7:] new_key = new_key.replace("model.prob", "model.logits") new_key = new_key.replace(".prob", ".logits") new_key = new_key.replace("model.drop", "model.dropout") new_key = new_key.replace(".drop", ".dropout") new_key = new_key.replace("FTGenerator.ft.", "FTGenerator.fourier_transform.") new_key = new_key.replace("FTGenerator.ft", "FTGenerator.fourier_transform") new_state_dict[new_key] = value model.load_state_dict(new_state_dict, strict=False) return model def export_to_onnx(model, output_path, input_size=128): print("Exporting model to ONNX...") print(f"Output path: {output_path}") model.eval() dummy_input = torch.randn(1, 3, input_size, input_size) torch.onnx.export( model, dummy_input, output_path, input_names=["input"], output_names=["output"], export_params=True, opset_version=13, do_constant_folding=True, dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, ) onnx_model = onnx.load(output_path) print("Simplifying ONNX model...") onnx_model, check = onnxsim.simplify(onnx_model) assert check, "Simplified ONNX model could not be validated" onnx.save(onnx_model, output_path) print("[OK] ONNX model exported") return output_path def quantize_onnx_with_ort(onnx_path, output_path): try: from onnxruntime.quantization import quantize_dynamic, QuantType print("\nQuantizing ONNX model with ONNX Runtime...") print(f"Input: {onnx_path}") print(f"Output: {output_path}") quantize_dynamic( model_input=onnx_path, model_output=output_path, weight_type=QuantType.QUInt8, ) print("[OK] Quantized ONNX model created") return output_path except ImportError: print( "[ERROR] onnxruntime not installed. Install with: pip install onnxruntime" ) return None except Exception as e: print(f"[ERROR] Quantization failed: {e}") return None if __name__ == "__main__": parser = argparse.ArgumentParser( description="Export model to ONNX and quantize it using ONNX Runtime" ) parser.add_argument("checkpoint_path", type=str, help="Path to .pth checkpoint") parser.add_argument( "--input_size", type=int, default=128, help="Input image size (default: 128)" ) parser.add_argument( "--output_onnx", type=str, default=None, help="Path to save regular .onnx (default: replaces .pth with .onnx)", ) parser.add_argument( "--output_quantized", type=str, default=None, help="Path to save quantized .onnx (default: adds _quantized suffix)", ) parser.add_argument( "--skip_regular", action="store_true", help="Skip exporting regular ONNX if it already exists", ) args = parser.parse_args() assert os.path.isfile( args.checkpoint_path ), f"Checkpoint not found: {args.checkpoint_path}" device = "cpu" print(f"Using device: {device}") print(f"\nLoading model from {args.checkpoint_path}...") model = load_model_from_checkpoint(args.checkpoint_path, device, args.input_size) print("[OK] Model loaded") if args.output_onnx is None: args.output_onnx = args.checkpoint_path.replace(".pth", ".onnx") if not args.skip_regular or not os.path.exists(args.output_onnx): export_to_onnx(model, args.output_onnx, args.input_size) onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024) print(f"Regular ONNX size: {onnx_size:.2f} MB") else: print(f"Using existing ONNX: {args.output_onnx}") if args.output_quantized is None: args.output_quantized = args.checkpoint_path.replace(".pth", "_quantized.onnx") result = quantize_onnx_with_ort(args.output_onnx, args.output_quantized) if result: quantized_size = os.path.getsize(args.output_quantized) / (1024 * 1024) onnx_size = os.path.getsize(args.output_onnx) / (1024 * 1024) print(f"\nQuantized ONNX size: {quantized_size:.2f} MB") print(f"Size reduction: {quantized_size/onnx_size*100:.1f}% of original") print(f"\n[OK] Done! Quantized ONNX saved: {args.output_quantized}") else: print( "\n[WARNING] Quantization failed. Regular ONNX is available at:", args.output_onnx, ) print( "For regular ONNX export only, use: python scripts/export_onnx.py " )