Kyle Pearson committed on
Commit
3d9c899
·
1 Parent(s): 2cda5f8

Add precision constraints for ops like Resize/Gather, optimize FP16 casting with onnxoptimizer fallback, implement dynamic cast helpers, update CoreML provider priority, improve mixed-float error handling

Browse files
Files changed (2) hide show
  1. convert_onnx.py +301 -28
  2. inference_onnx.py +13 -9
convert_onnx.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
9
 
10
  import numpy as np
11
  import onnx
 
12
  import onnxruntime as ort
13
  import torch
14
  import torch.nn as nn
@@ -195,11 +196,31 @@ def load_sharp_model(checkpoint_path=None):
195
  return predictor
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  def convert_to_fp16(onnx_path):
199
  """Convert an ONNX model to FP16 precision.
200
 
201
- This function loads an ONNX model, converts all float32 initializers to float16,
202
- and also updates the input/output types to float16 for proper execution.
 
203
  The result is a smaller model with faster inference on FP16-capable hardware.
204
  """
205
  LOGGER.info(f"Converting {onnx_path} to FP16...")
@@ -207,42 +228,47 @@ def convert_to_fp16(onnx_path):
207
  # Load the model
208
  model = onnx.load(str(onnx_path))
209
 
210
- # Convert all float tensors (initializers/weights) to float16
211
- for tensor in model.graph.initializer:
212
- if tensor.data_type == onnx.TensorProto.FLOAT:
213
- float16_tensor = onnx.numpy_helper.to_array(tensor).astype(np.float16)
214
- tensor.CopyFrom(onnx.numpy_helper.from_array(float16_tensor, tensor.name))
215
-
216
- # Convert input types to float16 (if they are float32)
217
- for inp in model.graph.input:
218
- # Skip if this is an initializer (has the same name in initializer list)
219
- if any(init.name == inp.name for init in model.graph.initializer):
220
- continue
221
- if inp.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
222
- inp.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
223
-
224
- # Convert output types to float16 (if they are float32)
225
- for out in model.graph.output:
226
- if out.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
227
- out.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
228
-
229
- # Update the opset domain to at least 13 for better FP16 support
230
  for opset in model.opset_import:
231
- if opset.domain == "" and opset.version < 13:
232
- opset.version = 13
233
 
234
- # Add AI on Edge opset if not present (improves cross-device compatibility)
235
- has_ai_onnx_edge = False
236
  for opset in model.opset_import:
237
  if opset.domain == "com.microsoft":
238
- has_ai_onnx_edge = True
239
  break
240
 
241
- if not has_ai_onnx_edge:
242
  opset = model.opset_import.add()
243
  opset.domain = "com.microsoft"
244
  opset.version = 1
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # Save the FP16 model
247
  onnx.save(model, str(onnx_path))
248
 
@@ -251,6 +277,253 @@ def convert_to_fp16(onnx_path):
251
  return onnx_path
252
 
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_external_data=None, fp16=False):
255
  LOGGER.info("Exporting to ONNX format...")
256
  predictor.depth_alignment.scale_map_estimator = None
 
9
 
10
  import numpy as np
11
  import onnx
12
+ import onnxoptimizer
13
  import onnxruntime as ort
14
  import torch
15
  import torch.nn as nn
 
196
  return predictor
197
 
198
 
199
# ONNX ops whose inputs are precision-sensitive: they should be left in
# float32 rather than blindly cast down to float16.
FLOAT32_CONSTRAINT_OPS = {
    'Resize',             # roi/scales inputs frequently require float32
    'Gather',             # indices are ints; some opset versions expect fp32 data
    'ScatterElements',    # sensitive data/indices handling
    'Tile',               # repeats is int64, but some opset versions check for fp32
    'Range',              # start/limit/delta are typically float32
    'NonMaxSuppression',  # box coordinates and score thresholds
    'NonZero',            # produces index output
    'TopK',               # produces values + indices
}

# Per-op map of input positions that must stay float32.
# Shape: {op_type: {input_index: True}} — True means keep that input as float32.
FLOAT32_CONSTRAINT_INPUTS = {
    'Resize': {1: True, 2: True},  # input 1 = roi, input 2 = scales (some ONNX versions)
}
218
  def convert_to_fp16(onnx_path):
219
  """Convert an ONNX model to FP16 precision.
220
 
221
+ Uses onnxoptimizer's cast_optimization pass to properly handle all
222
+ intermediate values and ensure type consistency throughout the graph.
223
+
224
  The result is a smaller model with faster inference on FP16-capable hardware.
225
  """
226
  LOGGER.info(f"Converting {onnx_path} to FP16...")
 
228
  # Load the model
229
  model = onnx.load(str(onnx_path))
230
 
231
+ # Update opset to 17 for better FP16 support
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  for opset in model.opset_import:
233
+ if opset.domain == "" and opset.version < 17:
234
+ opset.version = 17
235
 
236
+ # Add com.microsoft opset for Cast operations if needed
237
+ has_com_microsoft = False
238
  for opset in model.opset_import:
239
  if opset.domain == "com.microsoft":
240
+ has_com_microsoft = True
241
  break
242
 
243
+ if not has_com_microsoft:
244
  opset = model.opset_import.add()
245
  opset.domain = "com.microsoft"
246
  opset.version = 1
247
 
248
+ # Use onnxoptimizer's cast optimization to handle all intermediate values
249
+ # First, optimize the model to ensure clean graph structure
250
+ LOGGER.info("Running onnxoptimizer passes...")
251
+
252
+ # Check available optimization passes
253
+ available_passes = onnxoptimizer.get_available_passes()
254
+ LOGGER.debug(f"Available passes: {len(available_passes)}")
255
+
256
+ # Run cast optimization pass which handles FP16 conversion
257
+ try:
258
+ # The cast_optimization pass handles type propagation
259
+ model = onnxoptimizer.optimize(
260
+ model,
261
+ passes=['cast_optimization'],
262
+ fixed_point=False
263
+ )
264
+ LOGGER.info("Applied cast_optimization pass")
265
+ except Exception as e:
266
+ LOGGER.warning(f"cast_optimization failed: {e}, trying alternative approach")
267
+ # Alternative: manually handle the conversion
268
+
269
+ # If still has float32 types, use a more aggressive approach
270
+ model = _aggressive_fp16_cast(model)
271
+
272
  # Save the FP16 model
273
  onnx.save(model, str(onnx_path))
274
 
 
277
  return onnx_path
278
 
279
 
280
def _aggressive_fp16_cast(model: onnx.ModelProto) -> onnx.ModelProto:
    """Aggressively cast all float32 values in *model* to float16.

    Converts float32 initializers (weights) in place, retypes graph
    inputs/outputs, and inserts an explicit ``Cast`` node after the producer
    of every float32 intermediate value so downstream consumers see float16
    tensors. Inputs listed in ``FLOAT32_CONSTRAINT_INPUTS`` (e.g. Resize
    roi/scales) are left wired to the original float32 value, since those
    operators require float32 there.

    Args:
        model: The ONNX model to convert.

    Returns:
        The converted model (modified in place and returned).
    """
    LOGGER.info("Applying aggressive FP16 casting...")

    # Shape inference populates graph.value_info with type/shape entries for
    # intermediate values, which we need in order to find float32 tensors.
    LOGGER.info("Running shape inference to find all intermediate values...")
    try:
        model = onnx.shape_inference.infer_shapes(model)
    except Exception as e:
        LOGGER.warning(f"Shape inference failed: {e}")

    # Step 1: convert all float32 initializers (weights) directly to float16.
    initializer_count = 0
    for tensor in model.graph.initializer:
        if tensor.data_type == onnx.TensorProto.FLOAT:
            float16_data = onnx.numpy_helper.to_array(tensor).astype(np.float16)
            tensor.CopyFrom(onnx.numpy_helper.from_array(float16_data, tensor.name))
            initializer_count += 1
    LOGGER.info(f"Converted {initializer_count} initializers to FP16")

    # Step 2: retype true graph inputs (not initializer-backed ones) to FP16.
    initializer_names = {t.name for t in model.graph.initializer}
    for inp in model.graph.input:
        if inp.name in initializer_names:
            continue
        if inp.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
            inp.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16

    # Step 3: retype graph outputs to FP16.
    for out in model.graph.output:
        if out.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
            out.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16

    # Step 4: collect intermediate float32 values from value_info. Skip
    # initializer names: shape inference may list them here, but Step 1 has
    # already converted their data, so a float32 Cast would be stale.
    values_to_cast = {
        vi.name
        for vi in model.graph.value_info
        if vi.type.tensor_type.elem_type == onnx.TensorProto.FLOAT
        and vi.name not in initializer_names
    }
    LOGGER.info(f"Found {len(values_to_cast)} intermediate float32 values from value_info")
    if not values_to_cast:
        return model

    # Step 5: build one Cast(float16) node per float32 value, keyed by the
    # value it casts so each can later be placed right after its producer.
    cast_map = {}       # original name -> fp16 alias
    cast_node_for = {}  # original name -> Cast node
    for counter, val_name in enumerate(values_to_cast):
        cast_name = f"{val_name}_fp16"
        cast_map[val_name] = cast_name
        cast_node_for[val_name] = onnx.helper.make_node(
            'Cast',
            inputs=[val_name],
            outputs=[cast_name],
            to=onnx.TensorProto.FLOAT16,
            name=f"Cast_{counter}",
        )
    LOGGER.info(f"Created {len(cast_node_for)} Cast nodes for intermediate values")

    # Step 6: rewire consumers to the fp16 alias, except input positions
    # that must remain float32 for the consuming op (e.g. Resize roi/scales).
    for node in model.graph.node:
        constrained = FLOAT32_CONSTRAINT_INPUTS.get(node.op_type, {})
        for i, inp in enumerate(node.input):
            if inp in cast_map and not constrained.get(i):
                node.input[i] = cast_map[inp]

    # Step 7: declare the fp16 aliases in value_info. The original float32
    # entries are kept, because their producer nodes still emit those values.
    # (onnx.helper has no get_tensor_shape; read dims from the proto directly.)
    new_entries = []
    for vi in model.graph.value_info:
        if vi.name in cast_map:
            dims = [
                d.dim_value if d.HasField('dim_value') else d.dim_param
                for d in vi.type.tensor_type.shape.dim
            ]
            new_entries.append(onnx.helper.make_tensor_value_info(
                cast_map[vi.name], onnx.TensorProto.FLOAT16, dims))
    model.graph.value_info.extend(new_entries)

    # Step 8: insert each Cast node immediately after its producer so the
    # node list stays topologically ordered. Inserting all casts at one
    # shared index can place a Cast before the node that produces its input,
    # which yields an invalid model.
    new_nodes = []
    for node in model.graph.node:
        new_nodes.append(node)
        for out_name in node.output:
            if out_name in cast_node_for:
                new_nodes.append(cast_node_for.pop(out_name))
    # Any leftover casts read from graph inputs (no producer node); they are
    # safe at the very front of the node list.
    new_nodes = list(cast_node_for.values()) + new_nodes
    model.graph.ClearField('node')
    model.graph.node.extend(new_nodes)

    return model
399
+
400
+
401
def _cast_floats_to_fp16(model: onnx.ModelProto) -> onnx.ModelProto:
    """Add Cast nodes so every declared float32 value is available as float16.

    Walks the graph's declared types (initializers, inputs, value_info),
    creates a ``Cast``-to-float16 for each float32 value, and rewires
    consumers to the casted name. Graph outputs are skipped: they have no
    in-graph consumers, so casting them would only create dead nodes.

    Args:
        model: The ONNX model to patch.

    Returns:
        The patched model (modified in place and returned).
    """
    # Map every value name we know about to its declared element type.
    value_types = {}
    for tensor in model.graph.initializer:
        value_types[tensor.name] = tensor.data_type
    initializer_names = {t.name for t in model.graph.initializer}
    for inp in model.graph.input:
        if inp.name not in initializer_names:
            value_types[inp.name] = inp.type.tensor_type.elem_type
    for vi in model.graph.value_info:
        value_types[vi.name] = vi.type.tensor_type.elem_type

    # Names of graph outputs — excluded from casting (see docstring).
    output_names = {out.name for out in model.graph.output}

    fp16_values = {k for k, v in value_types.items() if v == onnx.TensorProto.FLOAT16}
    LOGGER.info(f"Found {len(fp16_values)} FP16 values in graph")

    float32_values = [
        k for k, v in value_types.items()
        if v == onnx.TensorProto.FLOAT and k not in output_names
    ]
    LOGGER.info(f"Found {len(float32_values)} float32 values to cast to float16")
    if not float32_values:
        return model

    # One Cast node per float32 value, remembered by the value it casts so
    # each can be inserted directly after its producer below.
    cast_map = {}       # original name -> casted name
    cast_node_for = {}  # original name -> Cast node
    for counter, val_name in enumerate(float32_values):
        cast_name = f"{val_name}_to_fp16"
        cast_map[val_name] = cast_name
        cast_node_for[val_name] = onnx.helper.make_node(
            'Cast',
            inputs=[val_name],
            outputs=[cast_name],
            to=onnx.TensorProto.FLOAT16,
            name=f"Cast_{counter}",
        )
    LOGGER.info(f"Created {len(cast_node_for)} Cast nodes")

    # Rewire every consumer to the fp16 alias. This runs before the Cast
    # nodes are added to the graph, so their own inputs are not rewritten.
    for node in model.graph.node:
        for i, inp in enumerate(node.input):
            if inp in cast_map:
                node.input[i] = cast_map[inp]

    # Declare the fp16 aliases in value_info; keep the original float32
    # entries because their producers still emit those tensors.
    # (onnx.helper has no get_tensor_shape; read dims from the proto directly.)
    new_entries = []
    for vi in model.graph.value_info:
        if vi.name in cast_map:
            dims = [
                d.dim_value if d.HasField('dim_value') else d.dim_param
                for d in vi.type.tensor_type.shape.dim
            ]
            new_entries.append(onnx.helper.make_tensor_value_info(
                cast_map[vi.name], onnx.TensorProto.FLOAT16, dims))
    model.graph.value_info.extend(new_entries)

    # Insert each Cast immediately after the node producing its input so the
    # node list stays topologically sorted. (Inserting all casts at one
    # shared index can place a Cast before its producer — invalid graph.)
    new_nodes = []
    for node in model.graph.node:
        new_nodes.append(node)
        for out_name in node.output:
            if out_name in cast_node_for:
                new_nodes.append(cast_node_for.pop(out_name))
    # Leftover casts read from graph inputs or initializers (no producer
    # node) and are safe at the very front of the node list.
    new_nodes = list(cast_node_for.values()) + new_nodes
    model.graph.ClearField('node')
    model.graph.node.extend(new_nodes)

    return model
516
+
517
+
518
def _ensure_fp16_types(model: onnx.ModelProto) -> onnx.ModelProto:
    """Guarantee that every float tensor in *model* is FP16.

    Thin wrapper that delegates to ``_cast_floats_to_fp16``, which inserts
    the Cast nodes needed to convert any remaining float32 tensors.
    """
    return _cast_floats_to_fp16(model)
525
+
526
+
527
  def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_external_data=None, fp16=False):
528
  LOGGER.info("Exporting to ONNX format...")
529
  predictor.depth_alignment.scale_map_estimator = None
inference_onnx.py CHANGED
@@ -75,19 +75,22 @@ def run_inference(onnx_path: str | Path, image: np.ndarray, disparity_factor: fl
75
 
76
  LOGGER.info(f"Loading ONNX model: {onnx_path}")
77
 
78
- # Try with default providers first, then fallback to CPU only
 
 
 
 
 
 
 
79
  try:
80
- session = ort.InferenceSession(str(onnx_path))
 
81
  except Exception as e:
82
- error_msg = str(e)
83
- if "tensor(float16)" in error_msg and "tensor(float)" in error_msg:
84
- LOGGER.error("FP16 model has mixed float16/float32 types. This model was converted incorrectly.")
85
- LOGGER.error("For FP16 inference on Apple Silicon, use the Core ML model (sharp.mlpackage) instead.")
86
- LOGGER.error("Or regenerate the ONNX model with proper FP16 conversion.")
87
- raise RuntimeError(f"Invalid FP16 model: {error_msg}")
88
- # Try CPU fallback
89
  try:
90
  session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
 
91
  except Exception as cpu_e:
92
  raise RuntimeError(f"Failed to load ONNX model: {cpu_e}")
93
 
@@ -300,3 +303,4 @@ def main():
300
 
301
  if __name__ == "__main__":
302
  main()
 
 
75
 
76
  LOGGER.info(f"Loading ONNX model: {onnx_path}")
77
 
78
+ # Try execution providers in order of preference
79
+ # CoreML is best for Apple Silicon (handles FP16 automatically)
80
+ # CPU is fallback for models that CoreML doesn't support
81
+
82
+ # Use all string providers with separate provider_options list
83
+ providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
84
+ provider_options = [{'AccelerateInference': True}, {}]
85
+
86
  try:
87
+ session = ort.InferenceSession(str(onnx_path), providers=providers, provider_options=provider_options)
88
+ LOGGER.info("Using CoreMLExecutionProvider for inference")
89
  except Exception as e:
90
+ LOGGER.warning(f"CoreML execution failed, trying CPU: {e}")
 
 
 
 
 
 
91
  try:
92
  session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
93
+ LOGGER.info("Using CPUExecutionProvider for inference")
94
  except Exception as cpu_e:
95
  raise RuntimeError(f"Failed to load ONNX model: {cpu_e}")
96
 
 
303
 
304
  if __name__ == "__main__":
305
  main()
306
+