"""
ONNX export utilities for model deployment.
ONNX (Open Neural Network Exchange) is a universal format that allows
models to run on different frameworks and platforms:
- TensorFlow, PyTorch, etc.
- Mobile devices (iOS, Android)
- Web browsers (ONNX.js)
- C++, Java, and other languages
- Optimized inference servers
"""
import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
from .config import CHECKPOINT_PATH, MODEL_DIR, IMAGE_SIZE
from .model import create_model, get_device
def export_to_onnx(
    checkpoint_path: Path = CHECKPOINT_PATH,
    output_path: Optional[Path] = None,
    opset_version: int = 18
) -> Path:
    """
    Export PyTorch model to ONNX format.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint
        output_path: Path for the ONNX model (default: models/best_model.onnx)
        opset_version: ONNX opset version (default 18, supported by recent
            ONNX Runtime releases; lower it for older runtimes)

    Returns:
        Path to the exported ONNX model
    """
    if output_path is None:
        output_path = MODEL_DIR / "best_model.onnx"
    # Make sure the destination directory exists before torch writes the file.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load model on CPU so the exported graph is device-independent.
    device = torch.device("cpu")
    model = create_model(pretrained=False, freeze_backbone=False, device=device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout/batchnorm updates for a deterministic trace

    # Dummy input defines the traced signature (NCHW, batch=1).
    dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    # Export to ONNX
    torch.onnx.export(
        model,
        dummy_input,
        str(output_path),  # str() for compatibility with older torch versions
        export_params=True,           # embed trained weights in the graph
        opset_version=opset_version,
        do_constant_folding=True,     # pre-compute constant subgraphs
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={
            'image': {0: 'batch_size'},   # variable batch size at runtime
            'logits': {0: 'batch_size'}
        }
    )
    print(f"Model exported to: {output_path}")
    print(f"File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")
    return output_path
def validate_onnx_model(
    onnx_path: Path,
    checkpoint_path: Path = CHECKPOINT_PATH,
    rtol: float = 1e-3,
    atol: float = 1e-5
) -> bool:
    """
    Validate that ONNX model produces same outputs as PyTorch model.

    Args:
        onnx_path: Path to ONNX model
        checkpoint_path: Path to PyTorch checkpoint
        rtol: Relative tolerance for comparison
        atol: Absolute tolerance for comparison

    Returns:
        True if outputs match, False otherwise
    """
    import onnx
    import onnxruntime as ort

    # Structural validation of the exported graph.
    graph = onnx.load(onnx_path)
    onnx.checker.check_model(graph)
    print("ONNX model structure is valid")

    # Rebuild the reference PyTorch model from the checkpoint.
    cpu = torch.device("cpu")
    reference = create_model(pretrained=False, freeze_backbone=False, device=cpu)
    state = torch.load(checkpoint_path, map_location=cpu)
    reference.load_state_dict(state['model_state_dict'])
    reference.eval()

    # One random probe input, run through both backends.
    probe = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)
    with torch.no_grad():
        torch_result = reference(probe).numpy()

    session = ort.InferenceSession(str(onnx_path))
    ort_result = session.run(None, {'image': probe.numpy()})[0]

    # Element-wise comparison within the given tolerances.
    matches = np.allclose(torch_result, ort_result, rtol=rtol, atol=atol)
    if matches:
        print("Validation PASSED: ONNX outputs match PyTorch outputs")
        print(f" PyTorch output: {torch_result.flatten()[:5]}...")
        print(f" ONNX output: {ort_result.flatten()[:5]}...")
    else:
        print("Validation FAILED: Outputs do not match!")
        print(f" Max difference: {np.max(np.abs(torch_result - ort_result))}")
    return matches
def predict_with_onnx(
    onnx_path: Path,
    image_tensor: np.ndarray
) -> Tuple[str, float]:
    """
    Run inference using ONNX Runtime.

    Args:
        onnx_path: Path to ONNX model
        image_tensor: Preprocessed image as numpy array (1, 3, 224, 224)

    Returns:
        Tuple of (predicted_class, confidence)
    """
    import onnxruntime as ort
    from .config import CLASS_NAMES

    # Create session
    ort_session = ort.InferenceSession(str(onnx_path))

    # Run inference
    logits = ort_session.run(
        None,
        {'image': image_tensor.astype(np.float32)}
    )[0]

    # Numerically stable sigmoid: the naive 1/(1+exp(-z)) overflows in
    # np.exp when z is a large negative number, so branch on the sign and
    # only ever exponentiate a non-positive value.
    z = float(logits[0, 0])
    if z >= 0:
        prob = 1.0 / (1.0 + np.exp(-z))
    else:
        exp_z = np.exp(z)
        prob = exp_z / (1.0 + exp_z)

    # Binary decision at 0.5; confidence is the probability of the
    # predicted class (so it is always >= 0.5).
    pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
    confidence = float(prob if prob > 0.5 else 1 - prob)
    return pred_class, confidence
if __name__ == "__main__":
    # Full pipeline smoke test: export, validate, then run one inference.
    separator = "=" * 50

    print(separator)
    print("EXPORTING MODEL TO ONNX")
    print(separator)
    onnx_path = export_to_onnx()

    print()
    print(separator)
    print("VALIDATING ONNX MODEL")
    print(separator)
    validate_onnx_model(onnx_path)

    print()
    print(separator)
    print("TESTING ONNX INFERENCE")
    print(separator)
    # Random tensor in place of a real preprocessed image.
    test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
    pred_class, confidence = predict_with_onnx(onnx_path, test_input)
    print(f"Test prediction: {pred_class} ({confidence:.1%})")