"""
Model export utilities (ONNX, quantization)
"""
import torch
import torch.onnx
from pathlib import Path
from typing import Tuple
def export_to_onnx(model: torch.nn.Module,
save_path: str,
input_size: Tuple[int, int] = (384, 384),
opset_version: int = 14):
"""
Export model to ONNX format
Args:
model: PyTorch model
save_path: Path to save ONNX model
input_size: Input image size (H, W)
opset_version: ONNX opset version
"""
model.eval()
device = next(model.parameters()).device
# Create dummy input
dummy_input = torch.randn(1, 3, input_size[0], input_size[1]).to(device)
# Export
torch.onnx.export(
model,
dummy_input,
save_path,
export_params=True,
opset_version=opset_version,
do_constant_folding=True,
input_names=['input'],
output_names=['output', 'features'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
}
)
print(f"Model exported to ONNX: {save_path}")
# Verify
try:
import onnx
onnx_model = onnx.load(save_path)
onnx.checker.check_model(onnx_model)
print("ONNX model verified successfully")
except ImportError:
print("onnx package not installed, skipping verification")
def export_to_torchscript(model: torch.nn.Module,
                          save_path: str,
                          input_size: Tuple[int, int] = (384, 384)):
    """
    Export a model to TorchScript by tracing it with a dummy input.

    Args:
        model: PyTorch model (switched to eval mode by this call)
        save_path: Destination file for the traced module
        input_size: Input image size (H, W)
    """
    model.eval()
    # Trace with a single dummy RGB image placed on the model's own device.
    example = torch.randn(1, 3, *input_size, device=next(model.parameters()).device)
    traced = torch.jit.trace(model, example)
    traced.save(save_path)
    print(f"Model exported to TorchScript: {save_path}")
def quantize_model(model: torch.nn.Module,
                   save_path: str,
                   quantization_type: str = 'dynamic'):
    """
    Quantize a model for faster CPU inference.

    Only dynamic quantization is implemented; the original docstring
    advertised 'static' as well, but that path always raised.

    Args:
        model: PyTorch model (moved to CPU and eval mode by this call)
        save_path: Path where the quantized state_dict is saved
        quantization_type: currently only 'dynamic' is supported

    Returns:
        The quantized model.

    Raises:
        ValueError: if quantization_type is not 'dynamic'.
    """
    # Validate up front so an invalid request has no side effects; the
    # original moved the model to CPU before raising.
    if quantization_type != 'dynamic':
        raise ValueError(f"Unsupported quantization type: {quantization_type}")

    model.eval()
    model = model.cpu()

    # Dynamic quantization: weights stored as int8, activations quantized
    # on the fly at inference time.
    # NOTE(review): torch's dynamic quantization only applies to Linear /
    # RNN-style modules; Conv2d in the spec set is effectively a no-op —
    # kept here for parity with the original call.
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear, torch.nn.Conv2d},
        dtype=torch.qint8
    )

    # Only the state_dict is persisted: callers must rebuild the
    # architecture and re-apply quantization before loading it back.
    torch.save(quantized_model.state_dict(), save_path)
    print(f"Quantized model saved to: {save_path}")
    return quantized_model