File size: 5,393 Bytes
af59988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
ONNX export utilities for model deployment.

ONNX (Open Neural Network Exchange) is a universal format that allows
models to run on different frameworks and platforms:
- TensorFlow, PyTorch, etc.
- Mobile devices (iOS, Android)
- Web browsers (ONNX.js)
- C++, Java, and other languages
- Optimized inference servers
"""

import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Optional

from .config import CHECKPOINT_PATH, MODEL_DIR, IMAGE_SIZE
from .model import create_model, get_device


def export_to_onnx(
    checkpoint_path: Path = CHECKPOINT_PATH,
    output_path: Optional[Path] = None,
    opset_version: int = 18
) -> Path:
    """
    Export PyTorch model to ONNX format.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint
        output_path: Path for the ONNX model (default: models/best_model.onnx)
        opset_version: ONNX opset version (default 18; lower values such as
            14 trade newer operators for broader runtime compatibility)

    Returns:
        Path to the exported ONNX model
    """
    if output_path is None:
        output_path = MODEL_DIR / "best_model.onnx"

    # Make sure the destination directory exists before torch tries to write.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load model on CPU: exporting on CPU keeps the graph device-agnostic.
    device = torch.device("cpu")
    model = create_model(pretrained=False, freeze_backbone=False, device=device)

    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — fine for checkpoints we produced ourselves, unsafe for
    # untrusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Dummy input used only for tracing (batch=1, 3 channels, square image).
    dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    # Export to ONNX. str() for the path keeps older torch versions happy,
    # which expect a string rather than a pathlib.Path.
    torch.onnx.export(
        model,
        dummy_input,
        str(output_path),
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,  # Fold constant subgraphs at export time
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={
            'image': {0: 'batch_size'},   # Variable batch size
            'logits': {0: 'batch_size'}
        }
    )

    print(f"Model exported to: {output_path}")
    print(f"File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")

    return output_path


def validate_onnx_model(
    onnx_path: Path,
    checkpoint_path: Path = CHECKPOINT_PATH,
    rtol: float = 1e-3,
    atol: float = 1e-5
) -> bool:
    """
    Check that the exported ONNX model reproduces the PyTorch model's outputs.

    Args:
        onnx_path: Path to ONNX model
        checkpoint_path: Path to PyTorch checkpoint
        rtol: Relative tolerance for comparison
        atol: Absolute tolerance for comparison

    Returns:
        True if outputs match, False otherwise
    """
    import onnx
    import onnxruntime as ort

    # Structural sanity check of the exported graph.
    onnx.checker.check_model(onnx.load(onnx_path))
    print("ONNX model structure is valid")

    # Rebuild the PyTorch reference model on CPU from the checkpoint.
    cpu = torch.device("cpu")
    reference = create_model(pretrained=False, freeze_backbone=False, device=cpu)
    state = torch.load(checkpoint_path, map_location=cpu)
    reference.load_state_dict(state['model_state_dict'])
    reference.eval()

    # One random sample is enough for a numerical parity check.
    sample = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    with torch.no_grad():
        torch_out = reference(sample).numpy()

    session = ort.InferenceSession(str(onnx_path))
    ort_out = session.run(None, {'image': sample.numpy()})[0]

    matches = np.allclose(torch_out, ort_out, rtol=rtol, atol=atol)

    if matches:
        print("Validation PASSED: ONNX outputs match PyTorch outputs")
        print(f"  PyTorch output: {torch_out.flatten()[:5]}...")
        print(f"  ONNX output:    {ort_out.flatten()[:5]}...")
    else:
        print("Validation FAILED: Outputs do not match!")
        print(f"  Max difference: {np.max(np.abs(torch_out - ort_out))}")

    return matches


def _stable_sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid for a scalar logit.

    The naive 1 / (1 + exp(-x)) overflows for large negative x; branching on
    the sign keeps the exp() argument non-positive in both cases.
    """
    if x >= 0:
        return float(1.0 / (1.0 + np.exp(-x)))
    e = np.exp(x)
    return float(e / (1.0 + e))


def predict_with_onnx(
    onnx_path: Path,
    image_tensor: np.ndarray
) -> Tuple[str, float]:
    """
    Run inference using ONNX Runtime.

    Args:
        onnx_path: Path to ONNX model
        image_tensor: Preprocessed image as numpy array (1, 3, 224, 224)

    Returns:
        Tuple of (predicted_class, confidence)
    """
    import onnxruntime as ort
    from .config import CLASS_NAMES

    # Create session (NOTE: built per call; cache it if used in a hot path)
    ort_session = ort.InferenceSession(str(onnx_path))

    # Run inference; model expects float32 input named 'image'
    logits = ort_session.run(
        None,
        {'image': image_tensor.astype(np.float32)}
    )[0]

    # Binary classification head: single logit -> sigmoid probability.
    prob = _stable_sigmoid(logits[0, 0])
    pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
    confidence = float(prob if prob > 0.5 else 1 - prob)

    return pred_class, confidence


if __name__ == "__main__":
    # Export model
    print("=" * 50)
    print("EXPORTING MODEL TO ONNX")
    print("=" * 50)

    onnx_path = export_to_onnx()

    print("\n" + "=" * 50)
    print("VALIDATING ONNX MODEL")
    print("=" * 50)

    validate_onnx_model(onnx_path)

    print("\n" + "=" * 50)
    print("TESTING ONNX INFERENCE")
    print("=" * 50)

    # Test with random input
    test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
    pred_class, confidence = predict_with_onnx(onnx_path, test_input)
    print(f"Test prediction: {pred_class} ({confidence:.1%})")