ayushm98 committed
Commit 16cd69d · Parent: ad57885

feat: add ONNX export script for production inference


- Convert PyTorch model to ONNX format
- Apply ONNX optimizations for BERT models
- Verify inference matches PyTorch outputs
- Benchmark PyTorch vs ONNX latency
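
For context, a minimal sketch of how a downstream service might consume the exported model with onnxruntime. The label index, the `model.optimized.onnx` filename, and `max_length=128` mirror the export settings in the script below; the exact serving code is an assumption, not part of this commit:

    import numpy as np
    import onnxruntime as ort
    from transformers import AutoTokenizer

    # Assumed layout: tokenizer files live alongside the exported model
    model_dir = "ml/artifacts/complexity-classifier"
    session = ort.InferenceSession(f"{model_dir}/model.optimized.onnx")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Tokenize with the same settings used at export time
    enc = tokenizer(
        "How do I reverse a list in Python?",
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="np",
    )
    logits = session.run(
        None,
        {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]},
    )[0]
    predicted_class = int(np.argmax(logits, axis=-1)[0])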

Files changed (1)
  1. ml/export/convert_to_onnx.py +254 -0
ml/export/convert_to_onnx.py ADDED
@@ -0,0 +1,254 @@
+ """Convert a trained PyTorch model to ONNX format for fast inference."""
+
+ import json
+ import time
+ from pathlib import Path
+
+ import numpy as np
+ import onnx
+ import onnxruntime as ort
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+
+ def convert_to_onnx(
+     model_dir: str = "ml/artifacts/complexity-classifier",
+     output_path: str | None = None,
+     opset_version: int = 14,
+     optimize: bool = True,
+ ) -> str:
+     """
+     Convert a trained HuggingFace model to ONNX format.
+
+     Args:
+         model_dir: Directory containing the trained model
+         output_path: Output path for the ONNX model (defaults to model_dir/model.onnx)
+         opset_version: ONNX opset version
+         optimize: Whether to apply ONNX Runtime transformer optimizations
+
+     Returns:
+         Path to the saved ONNX model
+     """
+     model_dir = Path(model_dir)
+     output_path = Path(output_path or model_dir / "model.onnx")
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     print("Converting model to ONNX")
+     print(f"  Model dir: {model_dir}")
+     print(f"  Output: {output_path}")
+     print(f"  Opset: {opset_version}")
+
+     # Load model and tokenizer
+     print("\nLoading model...")
+     tokenizer = AutoTokenizer.from_pretrained(model_dir)
+     model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+     model.eval()
+
+     # Create a dummy input for tracing
+     dummy_text = "This is a sample text for tracing the model."
+     dummy_inputs = tokenizer(
+         dummy_text,
+         padding="max_length",
+         truncation=True,
+         max_length=128,
+         return_tensors="pt",
+     )
+
+     # Define input names and dynamic axes so batch size and sequence
+     # length can vary at inference time
+     input_names = ["input_ids", "attention_mask"]
+     output_names = ["logits"]
+
+     dynamic_axes = {
+         "input_ids": {0: "batch_size", 1: "sequence"},
+         "attention_mask": {0: "batch_size", 1: "sequence"},
+         "logits": {0: "batch_size"},
+     }
+
+     # Export to ONNX
+     print("\nExporting to ONNX...")
+     torch.onnx.export(
+         model,
+         (dummy_inputs["input_ids"], dummy_inputs["attention_mask"]),
+         str(output_path),
+         input_names=input_names,
+         output_names=output_names,
+         dynamic_axes=dynamic_axes,
+         opset_version=opset_version,
+         do_constant_folding=True,
+     )
+
+     print(f"Model exported to: {output_path}")
+
+     # Validate the exported graph
+     print("\nValidating ONNX model...")
+     onnx_model = onnx.load(str(output_path))
+     onnx.checker.check_model(onnx_model)
+     print("ONNX model validation passed!")
+
+     # Apply optimizations if requested
+     if optimize:
+         print("\nApplying ONNX optimizations...")
+         from onnxruntime.transformers import optimizer
+
+         optimized_path = output_path.with_suffix(".optimized.onnx")
+         # num_heads/hidden_size must match the exported checkpoint;
+         # 12/768 corresponds to a BERT-base-sized backbone
+         optimized_model = optimizer.optimize_model(
+             str(output_path),
+             model_type="bert",
+             num_heads=12,
+             hidden_size=768,
+         )
+         optimized_model.save_model_to_file(str(optimized_path))
+         print(f"Optimized model saved to: {optimized_path}")
+
+         # Use the optimized model from here on
+         output_path = optimized_path
+
+     # Verify inference
+     print("\nVerifying inference...")
+     _verify_onnx_inference(model, tokenizer, output_path)
+
+     # Benchmark
+     print("\nBenchmarking...")
+     pytorch_time, onnx_time = _benchmark_inference(model, tokenizer, output_path)
+
+     # Save conversion info
+     info = {
+         "original_model": str(model_dir),
+         "onnx_path": str(output_path),
+         "opset_version": opset_version,
+         "optimized": optimize,
+         "benchmark": {
+             "pytorch_ms": pytorch_time,
+             "onnx_ms": onnx_time,
+             "speedup": pytorch_time / onnx_time if onnx_time > 0 else 0,
+         },
+     }
+
+     info_path = output_path.with_suffix(".json")
+     with open(info_path, "w") as f:
+         json.dump(info, f, indent=2)
+
+     print("\n" + "=" * 50)
+     print("Conversion complete!")
+     print("=" * 50)
+     print(f"\nONNX model: {output_path}")
+     print(f"PyTorch latency: {pytorch_time:.2f}ms")
+     print(f"ONNX latency: {onnx_time:.2f}ms")
+     print(f"Speedup: {pytorch_time / onnx_time:.2f}x")
+
+     return str(output_path)
+
+
+ def _verify_onnx_inference(model, tokenizer, onnx_path: Path) -> None:
+     """Verify the ONNX model produces the same outputs as PyTorch."""
+     # Create the session once, outside the loop
+     session = ort.InferenceSession(str(onnx_path))
+
+     test_texts = [
+         "Hello, how are you?",
+         "Write a Python function to calculate the factorial of a number recursively.",
+     ]
+
+     for text in test_texts:
+         inputs = tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=128,
+             return_tensors="pt",
+         )
+
+         # PyTorch inference
+         with torch.no_grad():
+             pytorch_outputs = model(**inputs)
+         pytorch_logits = pytorch_outputs.logits.numpy()
+
+         # ONNX inference
+         onnx_inputs = {
+             "input_ids": inputs["input_ids"].numpy(),
+             "attention_mask": inputs["attention_mask"].numpy(),
+         }
+         onnx_outputs = session.run(None, onnx_inputs)
+         onnx_logits = onnx_outputs[0]
+
+         # Compare within floating-point tolerance
+         np.testing.assert_allclose(pytorch_logits, onnx_logits, rtol=1e-3, atol=1e-4)
+
+     print("  Inference verification passed!")
+
+
+ def _benchmark_inference(
+     model, tokenizer, onnx_path: Path, num_runs: int = 100
+ ) -> tuple[float, float]:
+     """Benchmark PyTorch vs ONNX inference latency."""
+     test_text = "What is the capital of France?"
+     inputs = tokenizer(
+         test_text,
+         padding="max_length",
+         truncation=True,
+         max_length=128,
+         return_tensors="pt",
+     )
+
+     # Warmup both backends before timing
+     with torch.no_grad():
+         _ = model(**inputs)
+
+     session = ort.InferenceSession(str(onnx_path))
+     onnx_inputs = {
+         "input_ids": inputs["input_ids"].numpy(),
+         "attention_mask": inputs["attention_mask"].numpy(),
+     }
+     _ = session.run(None, onnx_inputs)
+
+     # Benchmark PyTorch
+     start = time.perf_counter()
+     for _ in range(num_runs):
+         with torch.no_grad():
+             _ = model(**inputs)
+     pytorch_time = (time.perf_counter() - start) / num_runs * 1000  # ms
+
+     # Benchmark ONNX
+     start = time.perf_counter()
+     for _ in range(num_runs):
+         _ = session.run(None, onnx_inputs)
+     onnx_time = (time.perf_counter() - start) / num_runs * 1000  # ms
+
+     return pytorch_time, onnx_time
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Convert model to ONNX")
+     parser.add_argument(
+         "--model-dir",
+         type=str,
+         default="ml/artifacts/complexity-classifier",
+         help="Model directory",
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         default=None,
+         help="Output path for ONNX model",
+     )
+     parser.add_argument(
+         "--opset",
+         type=int,
+         default=14,
+         help="ONNX opset version",
+     )
+     parser.add_argument(
+         "--no-optimize",
+         action="store_true",
+         help="Skip ONNX optimizations",
+     )
+
+     args = parser.parse_args()
+
+     convert_to_onnx(
+         model_dir=args.model_dir,
+         output_path=args.output,
+         opset_version=args.opset,
+         optimize=not args.no_optimize,
+     )
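
Example invocation, assuming the script is run from the repository root; the flags are the ones defined in the argparse block above, and the defaults shown match the in-code defaults:

    python ml/export/convert_to_onnx.py \
        --model-dir ml/artifacts/complexity-classifier \
        --opset 14

Passing --no-optimize skips the BERT-specific onnxruntime optimization pass and leaves only the plain model.onnx export.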