| """
|
| Model export utilities (ONNX, quantization)
|
| """
|
|
|
| import torch
|
| import torch.onnx
|
| from pathlib import Path
|
| from typing import Tuple
|
|
|
|
|
def export_to_onnx(model: torch.nn.Module,
                   save_path: str,
                   input_size: Tuple[int, int] = (384, 384),
                   opset_version: int = 14):
    """
    Export model to ONNX format.

    Args:
        model: PyTorch model
        save_path: Path to save ONNX model
        input_size: Input image size (H, W)
        opset_version: ONNX opset version
    """
    model.eval()
    device = next(model.parameters()).device

    # Trace with a single dummy batch; the batch axis is marked dynamic below.
    dummy_input = torch.randn(1, 3, input_size[0], input_size[1]).to(device)

    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output', 'features'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
            'features': {0: 'batch_size'}
        }
    )

    print(f"Model exported to ONNX: {save_path}")

    # Structural sanity check; requires the optional `onnx` package.
    try:
        import onnx
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print("ONNX model verified successfully")
    except ImportError:
        print("onnx package not installed, skipping verification")


def export_to_torchscript(model: torch.nn.Module,
                          save_path: str,
                          input_size: Tuple[int, int] = (384, 384)):
    """
    Export model to TorchScript format via tracing.

    Args:
        model: PyTorch model
        save_path: Path to save model
        input_size: Input image size (H, W)
    """
    model.eval()
    device = next(model.parameters()).device

    dummy_input = torch.randn(1, 3, input_size[0], input_size[1]).to(device)

    # Tracing records the ops run for this input; data-dependent control
    # flow is not captured (use torch.jit.script if the model needs it).
    traced_model = torch.jit.trace(model, dummy_input)

    traced_model.save(save_path)

    print(f"Model exported to TorchScript: {save_path}")


def quantize_model(model: torch.nn.Module,
                   save_path: str,
                   quantization_type: str = 'dynamic'):
    """
    Quantize model for faster CPU inference.

    Args:
        model: PyTorch model
        save_path: Path to save quantized model
        quantization_type: currently only 'dynamic' is supported
    """
    model.eval()
    model = model.cpu()  # quantized kernels run on CPU

    if quantization_type == 'dynamic':
        # Dynamic quantization stores int8 weights and quantizes activations
        # on the fly. It mainly benefits Linear (and RNN) layers; dynamic
        # Conv2d support depends on the PyTorch version.
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear, torch.nn.Conv2d},
            dtype=torch.qint8
        )
    else:
        raise ValueError(f"Unsupported quantization type: {quantization_type}")

    # To reload, rebuild the model, re-apply quantize_dynamic with the same
    # settings, then call load_state_dict on the result.
    torch.save(quantized_model.state_dict(), save_path)

    print(f"Quantized model saved to: {save_path}")

    return quantized_model
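

# A minimal end-to-end usage sketch. The toy model and file names below are
# assumptions for illustration, shaped to match the ['output', 'features']
# head names used in export_to_onnx; substitute your project's model.
if __name__ == '__main__':
    class _ToyModel(torch.nn.Module):
        """Tiny two-output model: (logits, pooled features)."""

        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.pool = torch.nn.AdaptiveAvgPool2d(1)
            self.head = torch.nn.Linear(8, 10)

        def forward(self, x):
            features = self.pool(self.backbone(x)).flatten(1)
            return self.head(features), features

    net = _ToyModel()
    export_to_onnx(net, 'toy_model.onnx')
    export_to_torchscript(net, 'toy_model_traced.pt')
    quantize_model(net, 'toy_model_int8.pth')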