"""
Model export utilities (ONNX, quantization)
"""

import torch
import torch.onnx
from pathlib import Path
from typing import Tuple


def _make_dummy_input(model: torch.nn.Module,
                      input_size: Tuple[int, int]) -> torch.Tensor:
    """Build a random (1, 3, H, W) tensor on the same device as the model."""
    # A module without parameters (e.g. a bare activation) has no device of
    # its own; fall back to CPU instead of leaking StopIteration to the caller.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    return torch.randn(1, 3, input_size[0], input_size[1]).to(device)


def export_to_onnx(model: torch.nn.Module,
                   save_path: str,
                   input_size: Tuple[int, int] = (384, 384),
                   opset_version: int = 14) -> None:
    """
    Export model to ONNX format

    The model is switched to eval mode and exported with a single random
    input of shape (1, 3, H, W). If the ``onnx`` package is importable, the
    written file is loaded back and run through ``onnx.checker.check_model``;
    a failed check propagates to the caller.

    Args:
        model: PyTorch model
        save_path: Path to save ONNX model
        input_size: Input image size (H, W)
        opset_version: ONNX opset version
    """
    model.eval()
    dummy_input = _make_dummy_input(model, input_size)

    # NOTE(review): two output names are declared, but only 'output' gets a
    # dynamic batch axis - confirm whether 'features' should have one too.
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output', 'features'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )

    print(f"Model exported to ONNX: {save_path}")

    # Verification is optional: only a missing onnx package is tolerated.
    try:
        import onnx
    except ImportError:
        print("onnx package not installed, skipping verification")
    else:
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print("ONNX model verified successfully")


def export_to_torchscript(model: torch.nn.Module,
                          save_path: str,
                          input_size: Tuple[int, int] = (384, 384)) -> None:
    """
    Export model to TorchScript format

    The model is switched to eval mode and traced (``torch.jit.trace``) with
    a single random input of shape (1, 3, H, W).

    Args:
        model: PyTorch model
        save_path: Path to save model
        input_size: Input image size (H, W)
    """
    model.eval()
    dummy_input = _make_dummy_input(model, input_size)

    # Trace the model with the dummy input
    traced_model = torch.jit.trace(model, dummy_input)

    # Save
    traced_model.save(save_path)
    print(f"Model exported to TorchScript: {save_path}")


def quantize_model(model: torch.nn.Module,
                   save_path: str,
                   quantization_type: str = 'dynamic') -> torch.nn.Module:
    """
    Quantize model for faster inference

    The model passed in is switched to eval mode and moved to CPU before
    quantization. Only the quantized model's ``state_dict`` is written to
    ``save_path``.

    Args:
        model: PyTorch model
        save_path: Path to save quantized model
        quantization_type: Only 'dynamic' is currently supported

    Returns:
        The quantized model

    Raises:
        ValueError: If ``quantization_type`` is not 'dynamic'
    """
    # Validate first so an unsupported request leaves the caller's model
    # untouched (no eval()/cpu() side effects before the error).
    if quantization_type != 'dynamic':
        raise ValueError(f"Unsupported quantization type: {quantization_type}")

    model.eval()
    model = model.cpu()

    # Dynamic quantization (easiest)
    # NOTE(review): Conv2d is listed here, but PyTorch's dynamic quantization
    # documentation lists Linear and recurrent layers as the supported types -
    # confirm whether Conv2d layers are actually converted.
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear, torch.nn.Conv2d},
        dtype=torch.qint8
    )

    # Save
    torch.save(quantized_model.state_dict(), save_path)
    print(f"Quantized model saved to: {save_path}")

    return quantized_model