"""
ONNX export utilities for model deployment.

ONNX (Open Neural Network Exchange) is a universal format that allows models
to run on different frameworks and platforms:
- TensorFlow, PyTorch, etc.
- Mobile devices (iOS, Android)
- Web browsers (ONNX Runtime Web)
- C++, Java, and other languages
- Optimized inference servers
"""

import math
from pathlib import Path
from typing import Optional, Sequence, Tuple

import numpy as np
import torch

from .config import CHECKPOINT_PATH, MODEL_DIR, IMAGE_SIZE
from .model import create_model, get_device

# Tensor names baked into the exported graph. ONNX Runtime feeds/fetches are
# keyed by these, so export, validation and inference must all agree.
_INPUT_NAME = "image"
_OUTPUT_NAME = "logits"

# Default ONNX Runtime execution providers. Recent onnxruntime releases require
# an explicit provider list when the installed build also ships non-CPU
# providers; CPU also matches the CPU PyTorch reference used for validation.
_DEFAULT_PROVIDERS = ("CPUExecutionProvider",)


def _load_pytorch_model(checkpoint_path: Path, device: torch.device) -> torch.nn.Module:
    """Build the model architecture, load trained weights, and switch to eval mode.

    Args:
        checkpoint_path: Path to a checkpoint dict containing 'model_state_dict'.
        device: Device to build the model on and map the checkpoint tensors to.

    Returns:
        The model in eval mode.
    """
    model = create_model(pretrained=False, freeze_backbone=False, device=device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model


def _sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid (never overflows for large |x|)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0, so z is in (0, 1): no overflow
    return z / (1.0 + z)


def export_to_onnx(
    checkpoint_path: Path = CHECKPOINT_PATH,
    output_path: Optional[Path] = None,
    opset_version: int = 18,
) -> Path:
    """
    Export PyTorch model to ONNX format.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint
        output_path: Path for the ONNX model (default: MODEL_DIR / "best_model.onnx").
            A ``str`` is also accepted; missing parent directories are created.
        opset_version: ONNX opset version to target. Higher opsets need newer
            runtimes; pass a lower value (e.g. 14) for wider compatibility.

    Returns:
        Path to the exported ONNX model
    """
    output_path = MODEL_DIR / "best_model.onnx" if output_path is None else Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    device = torch.device("cpu")  # Export on CPU for compatibility
    model = _load_pytorch_model(checkpoint_path, device)

    # The dummy input only fixes the traced shape (1, 3, IMAGE_SIZE, IMAGE_SIZE);
    # the batch dimension is declared dynamic below.
    dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    torch.onnx.export(
        model,
        dummy_input,
        str(output_path),
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,  # Optimize constants
        input_names=[_INPUT_NAME],
        output_names=[_OUTPUT_NAME],
        dynamic_axes={
            _INPUT_NAME: {0: "batch_size"},  # Variable batch size
            _OUTPUT_NAME: {0: "batch_size"},
        },
    )

    print(f"Model exported to: {output_path}")
    print(f"File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")

    return output_path


def validate_onnx_model(
    onnx_path: Path,
    checkpoint_path: Path = CHECKPOINT_PATH,
    rtol: float = 1e-3,
    atol: float = 1e-5,
    batch_size: int = 1,
) -> bool:
    """
    Validate that ONNX model produces same outputs as PyTorch model.

    Both models are run on the same random input on CPU and compared
    element-wise with ``np.allclose``.

    Args:
        onnx_path: Path to ONNX model
        checkpoint_path: Path to PyTorch checkpoint
        rtol: Relative tolerance for comparison
        atol: Absolute tolerance for comparison
        batch_size: Batch size of the random test input; values > 1 also
            exercise the exported dynamic batch axis.

    Returns:
        True if outputs match (same shape and within tolerance), False otherwise

    Raises:
        ValueError: If batch_size < 1.
    """
    # Imported lazily so the export path does not require onnxruntime.
    import onnx
    import onnxruntime as ort

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Check ONNX model is valid
    onnx_model = onnx.load(str(onnx_path))
    onnx.checker.check_model(onnx_model)
    print("ONNX model structure is valid")

    model = _load_pytorch_model(checkpoint_path, torch.device("cpu"))

    test_input = torch.randn(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)

    with torch.no_grad():
        pytorch_output = model(test_input).numpy()

    ort_session = ort.InferenceSession(str(onnx_path), providers=list(_DEFAULT_PROVIDERS))
    onnx_output = ort_session.run(None, {_INPUT_NAME: test_input.numpy()})[0]

    # np.allclose broadcasts, so a shape mismatch could otherwise "pass".
    if pytorch_output.shape != onnx_output.shape:
        print("Validation FAILED: Outputs do not match!")
        print(f"  Shape mismatch: PyTorch {pytorch_output.shape} vs ONNX {onnx_output.shape}")
        return False

    is_close = bool(np.allclose(pytorch_output, onnx_output, rtol=rtol, atol=atol))

    if is_close:
        print("Validation PASSED: ONNX outputs match PyTorch outputs")
        print(f"  PyTorch output: {pytorch_output.flatten()[:5]}...")
        print(f"  ONNX output: {onnx_output.flatten()[:5]}...")
    else:
        print("Validation FAILED: Outputs do not match!")
        print(f"  Max difference: {np.max(np.abs(pytorch_output - onnx_output))}")

    return is_close


def predict_with_onnx(
    onnx_path: Path,
    image_tensor: np.ndarray,
    providers: Optional[Sequence[str]] = None,
) -> Tuple[str, float]:
    """
    Run inference using ONNX Runtime.

    Args:
        onnx_path: Path to ONNX model
        image_tensor: Preprocessed image as a numpy array of shape
            (1, 3, IMAGE_SIZE, IMAGE_SIZE); a single (3, IMAGE_SIZE, IMAGE_SIZE)
            image is also accepted. Only the first sample of a batch is scored.
        providers: ONNX Runtime execution providers (default: CPU only).

    Returns:
        Tuple of (predicted_class, confidence)
    """
    import onnxruntime as ort
    from .config import CLASS_NAMES

    # ONNX Runtime expects a contiguous float32 buffer.
    batch = np.ascontiguousarray(image_tensor, dtype=np.float32)
    if batch.ndim == 3:
        batch = batch[np.newaxis]  # (C, H, W) -> (1, C, H, W)

    session_providers = list(providers) if providers is not None else list(_DEFAULT_PROVIDERS)
    ort_session = ort.InferenceSession(str(onnx_path), providers=session_providers)

    logits = ort_session.run(None, {_INPUT_NAME: batch})[0]

    # Single-logit binary head: sigmoid gives P(CLASS_NAMES[1]).
    prob = _sigmoid(float(logits[0, 0]))
    if prob > 0.5:
        return CLASS_NAMES[1], prob
    return CLASS_NAMES[0], 1.0 - prob


if __name__ == "__main__":
    # Export model
    print("=" * 50)
    print("EXPORTING MODEL TO ONNX")
    print("=" * 50)
    onnx_path = export_to_onnx()

    print("\n" + "=" * 50)
    print("VALIDATING ONNX MODEL")
    print("=" * 50)
    validate_onnx_model(onnx_path)

    print("\n" + "=" * 50)
    print("TESTING ONNX INFERENCE")
    print("=" * 50)
    # Test with random input
    test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
    pred_class, confidence = predict_with_onnx(onnx_path, test_input)
    print(f"Test prediction: {pred_class} ({confidence:.1%})")