Kyle Pearson committed on
Commit
1dd5974
·
1 Parent(s): 3d9c899

Fix precision tolerances, remove legacy FP16 logic, update data handling, standardize execution provider

Browse files
Files changed (2) hide show
  1. convert_onnx.py +36 -387
  2. inference_onnx.py +4 -19
convert_onnx.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
9
 
10
  import numpy as np
11
  import onnx
 
12
  import onnxoptimizer
13
  import onnxruntime as ort
14
  import torch
@@ -43,7 +44,7 @@ class ToleranceConfig:
43
  self.random_tolerances = {
44
  "mean_vectors_3d_positions": 0.001,
45
  "singular_values_scales": 0.0001,
46
- "quaternions_rotations": 10.0, # Increased for ONNX numerical precision
47
  "colors_rgb_linear": 0.002,
48
  "opacities_alpha_channel": 0.005,
49
  }
@@ -51,12 +52,12 @@ class ToleranceConfig:
51
  self.image_tolerances = {
52
  "mean_vectors_3d_positions": 3.5,
53
  "singular_values_scales": 0.035,
54
- "quaternions_rotations": 10.0, # Increased for ONNX numerical precision
55
  "colors_rgb_linear": 0.01,
56
  "opacities_alpha_channel": 0.05,
57
  }
58
  if self.angular_tolerances_random is None:
59
- self.angular_tolerances_random = {"mean": 0.01, "p99": 0.1, "p99_9": 1.0, "max": 10.0} # Increased for ONNX precision
60
  if self.angular_tolerances_image is None:
61
  self.angular_tolerances_image = {"mean": 0.2, "p99": 2.0, "p99_9": 5.0, "max": 25.0}
62
 
@@ -147,7 +148,7 @@ def cleanup_onnx_files(onnx_path):
147
  try:
148
  if onnx_path.exists():
149
  onnx_path.unlink()
150
- LOGGER.info(f"Removed {onnx_path}")
151
  except Exception as e:
152
  LOGGER.warning(f"Could not remove {onnx_path}: {e}")
153
 
@@ -156,7 +157,7 @@ def cleanup_onnx_files(onnx_path):
156
  try:
157
  if data_path.exists():
158
  data_path.unlink()
159
- LOGGER.info(f"Removed {data_path}")
160
  except Exception as e:
161
  LOGGER.warning(f"Could not remove {data_path}: {e}")
162
 
@@ -167,7 +168,7 @@ def cleanup_onnx_files(onnx_path):
167
  for f in glob.glob(pattern):
168
  try:
169
  Path(f).unlink()
170
- LOGGER.info(f"Removed temporary file {f}")
171
  except Exception:
172
  pass
173
 
@@ -196,335 +197,7 @@ def load_sharp_model(checkpoint_path=None):
196
  return predictor
197
 
198
 
199
- # Operators that require float32 for certain inputs and should not be converted
200
- FLOAT32_CONSTRAINT_OPS = {
201
- 'Resize', # scales and roi inputs often need float32
202
- 'Gather', # indices need int, data can be fp16 but some versions expect fp32
203
- 'ScatterElements', # data and indices handling
204
- 'Tile', # repeats input often expects int64 but some versions check for fp32
205
- 'Range', # start, limit, delta typically float32
206
- 'NonMaxSuppression', # box coordinates and thresholds
207
- 'NonZero', # indices output
208
- 'TopK', # values and indices
209
- }
210
-
211
- # Input indices for each operator that typically should remain float32
212
- # Format: {operator: {input_index: True}} - True means keep as float32
213
- FLOAT32_CONSTRAINT_INPUTS = {
214
- 'Resize': {1: True, 2: True}, # roi (1), scales (2) - in some ONNX versions
215
- }
216
-
217
-
218
- def convert_to_fp16(onnx_path):
219
- """Convert an ONNX model to FP16 precision.
220
-
221
- Uses onnxoptimizer's cast_optimization pass to properly handle all
222
- intermediate values and ensure type consistency throughout the graph.
223
-
224
- The result is a smaller model with faster inference on FP16-capable hardware.
225
- """
226
- LOGGER.info(f"Converting {onnx_path} to FP16...")
227
-
228
- # Load the model
229
- model = onnx.load(str(onnx_path))
230
-
231
- # Update opset to 17 for better FP16 support
232
- for opset in model.opset_import:
233
- if opset.domain == "" and opset.version < 17:
234
- opset.version = 17
235
-
236
- # Add com.microsoft opset for Cast operations if needed
237
- has_com_microsoft = False
238
- for opset in model.opset_import:
239
- if opset.domain == "com.microsoft":
240
- has_com_microsoft = True
241
- break
242
-
243
- if not has_com_microsoft:
244
- opset = model.opset_import.add()
245
- opset.domain = "com.microsoft"
246
- opset.version = 1
247
-
248
- # Use onnxoptimizer's cast optimization to handle all intermediate values
249
- # First, optimize the model to ensure clean graph structure
250
- LOGGER.info("Running onnxoptimizer passes...")
251
-
252
- # Check available optimization passes
253
- available_passes = onnxoptimizer.get_available_passes()
254
- LOGGER.debug(f"Available passes: {len(available_passes)}")
255
-
256
- # Run cast optimization pass which handles FP16 conversion
257
- try:
258
- # The cast_optimization pass handles type propagation
259
- model = onnxoptimizer.optimize(
260
- model,
261
- passes=['cast_optimization'],
262
- fixed_point=False
263
- )
264
- LOGGER.info("Applied cast_optimization pass")
265
- except Exception as e:
266
- LOGGER.warning(f"cast_optimization failed: {e}, trying alternative approach")
267
- # Alternative: manually handle the conversion
268
-
269
- # If still has float32 types, use a more aggressive approach
270
- model = _aggressive_fp16_cast(model)
271
-
272
- # Save the FP16 model
273
- onnx.save(model, str(onnx_path))
274
-
275
- size_mb = Path(onnx_path).stat().st_size / (1024 * 1024)
276
- LOGGER.info(f"FP16 model saved: {onnx_path} ({size_mb:.2f} MB)")
277
- return onnx_path
278
-
279
-
280
- def _aggressive_fp16_cast(model: onnx.ModelProto) -> onnx.ModelProto:
281
- """Aggressively cast all float32 values to float16.
282
-
283
- This function converts initializers and adds Cast nodes for intermediate
284
- values to ensure type consistency throughout the graph.
285
- """
286
- LOGGER.info("Applying aggressive FP16 casting...")
287
-
288
- # Run shape inference to populate value_info with all intermediate values
289
- LOGGER.info("Running shape inference to find all intermediate values...")
290
- try:
291
- model = onnx.shape_inference.infer_shapes(model)
292
- except Exception as e:
293
- LOGGER.warning(f"Shape inference failed: {e}")
294
-
295
- # Step 1: Convert all initializers (weights) directly to float16
296
- initializer_count = 0
297
- for tensor in model.graph.initializer:
298
- if tensor.data_type == onnx.TensorProto.FLOAT:
299
- float16_data = onnx.numpy_helper.to_array(tensor).astype(np.float16)
300
- tensor.CopyFrom(onnx.numpy_helper.from_array(float16_data, tensor.name))
301
- initializer_count += 1
302
-
303
- LOGGER.info(f"Converted {initializer_count} initializers to FP16")
304
-
305
- # Step 2: Convert graph inputs to FP16
306
- initializer_names = {t.name for t in model.graph.initializer}
307
- for inp in model.graph.input:
308
- if inp.name in initializer_names:
309
- continue
310
- if inp.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
311
- inp.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
312
-
313
- # Step 3: Convert graph outputs to FP16
314
- for out in model.graph.output:
315
- if out.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
316
- out.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
317
-
318
- # Step 4: Find all float32 values (from initializers, value_info, and node outputs)
319
- values_to_cast = set()
320
-
321
- # From value_info
322
- for vi in model.graph.value_info:
323
- if vi.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
324
- values_to_cast.add(vi.name)
325
-
326
- # Also check node outputs - some may be float32 but not in value_info
327
- node_output_types = {} # output_name -> type
328
- for node in model.graph.node:
329
- for out in node.output:
330
- node_output_types[out] = node.op_type
331
-
332
- LOGGER.info(f"Found {len(values_to_cast)} intermediate float32 values from value_info")
333
-
334
- if not values_to_cast:
335
- return model
336
-
337
- # Step 5: Create cast nodes for intermediate values
338
- cast_nodes = []
339
- cast_map = {} # original_name -> casted_name
340
- node_name_counter = 0
341
-
342
- for val_name in values_to_cast:
343
- cast_name = f"{val_name}_fp16"
344
- cast_map[val_name] = cast_name
345
-
346
- cast_node = onnx.helper.make_node(
347
- 'Cast',
348
- inputs=[val_name],
349
- outputs=[cast_name],
350
- to=onnx.TensorProto.FLOAT16,
351
- name=f"Cast_{node_name_counter}"
352
- )
353
- cast_nodes.append(cast_node)
354
- node_name_counter += 1
355
-
356
- LOGGER.info(f"Created {len(cast_nodes)} Cast nodes for intermediate values")
357
-
358
- # Step 6: Update node inputs to use casted values
359
- for node in model.graph.node:
360
- for i, inp in enumerate(node.input):
361
- if inp in cast_map:
362
- node.input[i] = cast_map[inp]
363
-
364
- # Step 7: Update value_info to reflect new types
365
- new_value_info = []
366
- for vi in model.graph.value_info:
367
- if vi.name in cast_map:
368
- shape = onnx.helper.get_tensor_shape(vi)
369
- new_vi = onnx.helper.make_tensor_value_info(
370
- cast_map[vi.name],
371
- onnx.TensorProto.FLOAT16,
372
- shape
373
- )
374
- new_value_info.append(new_vi)
375
- else:
376
- new_value_info.append(vi)
377
-
378
- model.graph.ClearField('value_info')
379
- for vi in new_value_info:
380
- model.graph.value_info.append(vi)
381
-
382
- # Step 8: Insert cast nodes at the beginning of the graph
383
- insert_indices = []
384
- cast_outputs = set(cast_map.values())
385
- for i, node in enumerate(model.graph.node):
386
- for inp in node.input:
387
- if inp in cast_outputs:
388
- insert_indices.append(i)
389
- break
390
-
391
- insert_index = min(insert_indices) if insert_indices else len(model.graph.node)
392
-
393
- new_nodes = list(model.graph.node[:insert_index]) + cast_nodes + list(model.graph.node[insert_index:])
394
- model.graph.ClearField('node')
395
- for node in new_nodes:
396
- model.graph.node.append(node)
397
-
398
- return model
399
-
400
-
401
- def _cast_floats_to_fp16(model: onnx.ModelProto) -> onnx.ModelProto:
402
- """Add Cast nodes to convert all float32 tensors to float16.
403
-
404
- This approach checks each node's inputs and adds Cast nodes for any float32
405
- inputs when the node also has float16 inputs, ensuring type consistency.
406
- """
407
- # Build a map of known value types
408
- value_types = {}
409
-
410
- # From initializers
411
- for tensor in model.graph.initializer:
412
- value_types[tensor.name] = tensor.data_type
413
-
414
- # From inputs
415
- initializer_names = {t.name for t in model.graph.initializer}
416
- for inp in model.graph.input:
417
- if inp.name not in initializer_names:
418
- value_types[inp.name] = inp.type.tensor_type.elem_type
419
-
420
- # From outputs
421
- for out in model.graph.output:
422
- value_types[out.name] = out.type.tensor_type.elem_type
423
-
424
- # From value_info
425
- for vi in model.graph.value_info:
426
- value_types[vi.name] = vi.type.tensor_type.elem_type
427
-
428
- # Track values that are FP16 (to avoid re-casting)
429
- fp16_values = {k for k, v in value_types.items() if v == onnx.TensorProto.FLOAT16}
430
-
431
- LOGGER.info(f"Found {len(fp16_values)} FP16 values in graph")
432
-
433
- # Find all float32 values that need casting
434
- float32_values = [k for k, v in value_types.items() if v == onnx.TensorProto.FLOAT]
435
- LOGGER.info(f"Found {len(float32_values)} float32 values to cast to float16")
436
-
437
- if not float32_values:
438
- return model
439
-
440
- # Create Cast nodes for each value that needs conversion
441
- cast_nodes = []
442
- cast_outputs = set()
443
- node_name_counter = 0
444
-
445
- # Create a mapping of original values to their casted versions
446
- cast_map = {}
447
-
448
- for val_name in float32_values:
449
- if val_name in cast_outputs or val_name in fp16_values:
450
- continue
451
-
452
- cast_name = f"{val_name}_to_fp16"
453
- cast_map[val_name] = cast_name
454
- cast_outputs.add(cast_name)
455
-
456
- cast_node = onnx.helper.make_node(
457
- 'Cast',
458
- inputs=[val_name],
459
- outputs=[cast_name],
460
- to=onnx.TensorProto.FLOAT16,
461
- name=f"Cast_{node_name_counter}"
462
- )
463
- cast_nodes.append(cast_node)
464
- node_name_counter += 1
465
-
466
- LOGGER.info(f"Created {len(cast_nodes)} Cast nodes")
467
-
468
- if not cast_nodes:
469
- return model
470
-
471
- # Update node inputs to use casted values
472
- for node in model.graph.node:
473
- for i, inp in enumerate(node.input):
474
- if inp in cast_map:
475
- node.input[i] = cast_map[inp]
476
-
477
- # Update value_info to reflect new types
478
- new_value_info = []
479
- for vi in model.graph.value_info:
480
- if vi.name in cast_map:
481
- # Create new value_info with FP16 type
482
- shape = onnx.helper.get_tensor_shape(vi)
483
- new_vi = onnx.helper.make_tensor_value_info(
484
- cast_map[vi.name],
485
- onnx.TensorProto.FLOAT16,
486
- shape
487
- )
488
- new_value_info.append(new_vi)
489
- else:
490
- new_value_info.append(vi)
491
-
492
- model.graph.ClearField('value_info')
493
- for vi in new_value_info:
494
- model.graph.value_info.append(vi)
495
-
496
- # Insert Cast nodes at the beginning of the graph (before any consumer)
497
- insert_indices = []
498
- for i, node in enumerate(model.graph.node):
499
- for inp in node.input:
500
- if inp in cast_outputs:
501
- insert_indices.append(i)
502
- break
503
-
504
- if insert_indices:
505
- insert_index = min(insert_indices)
506
- else:
507
- insert_index = len(model.graph.node)
508
-
509
- # Insert cast nodes
510
- new_nodes = list(model.graph.node[:insert_index]) + cast_nodes + list(model.graph.node[insert_index:])
511
- model.graph.ClearField('node')
512
- for node in new_nodes:
513
- model.graph.node.append(node)
514
-
515
- return model
516
-
517
-
518
- def _ensure_fp16_types(model: onnx.ModelProto) -> onnx.ModelProto:
519
- """Ensure all float tensors in the model are FP16.
520
-
521
- This function traverses the graph and adds Cast nodes where needed
522
- to convert any remaining float32 tensors to float16.
523
- """
524
- return _cast_floats_to_fp16(model)
525
-
526
-
527
- def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_external_data=None, fp16=False):
528
  LOGGER.info("Exporting to ONNX format...")
529
  predictor.depth_alignment.scale_map_estimator = None
530
  model = SharpModelTraceable(predictor)
@@ -544,15 +217,11 @@ def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_extern
544
 
545
  LOGGER.info(f"Exporting to ONNX: {output_path}")
546
 
547
- # Dynamic axes: opacities has shape (1, N) so axis 0 is the batch, axis 1 is num_gaussians
548
- # All other outputs have shape (1, N, C) where C is 3, 3, 4, 3 respectively
549
  dynamic_axes = {}
550
  for name in OUTPUT_NAMES:
551
  if name == "opacities_alpha_channel":
552
- # opacities is 2D: (batch, num_gaussians)
553
  dynamic_axes[name] = {0: 'batch', 1: 'num_gaussians'}
554
  else:
555
- # All other outputs are 3D: (batch, num_gaussians, channels)
556
  dynamic_axes[name] = {0: 'batch', 1: 'num_gaussians'}
557
 
558
  torch.onnx.export(
@@ -561,42 +230,29 @@ def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_extern
561
  input_names=['image', 'disparity_factor'],
562
  output_names=OUTPUT_NAMES,
563
  dynamic_axes=dynamic_axes,
564
- opset_version=15, # Use opset 15 for better browser compatibility
 
565
  )
566
 
567
- # Handle external data based on use_external_data parameter
568
- try:
569
- model_proto = onnx.load(str(output_path))
570
- model_size_mb = model_proto.ByteSize() / (1024 * 1024)
571
- LOGGER.info(f"Model size: {model_size_mb:.2f} MB")
572
-
573
- # Default: use external data for models > 100MB (not typical for browser)
574
- # use_external_data=True: always use external data
575
- # use_external_data=False: never use external data (inline mode for browser)
576
- use_ext = use_external_data if use_external_data is not None else (model_size_mb > 100)
577
-
578
- if use_ext:
579
- LOGGER.info("Saving with external data format...")
580
- data_path = output_path.with_suffix('.onnx.data')
581
- onnx.save_model(model_proto, str(output_path), save_as_external_data=True,
582
- all_tensors_to_one_file=True, location=data_path.name)
583
- LOGGER.info(f"External data saved to: {data_path}")
584
- else:
585
- LOGGER.info("Using inline data format (no external .onnx.data file needed)")
586
- except Exception as e:
587
- LOGGER.warning(f"External data format check failed: {e}")
588
-
589
- try:
590
- onnx.checker.check_model(str(output_path))
591
- LOGGER.info("ONNX model validation passed")
592
- except Exception as e:
593
- LOGGER.warning(f"ONNX model validation skipped: {e}")
594
-
595
- # Apply FP16 quantization if requested
596
- if fp16:
597
- convert_to_fp16(output_path)
598
-
599
- cleanup_extraneous_files()
600
  return output_path
601
 
602
 
@@ -616,7 +272,7 @@ def load_and_preprocess_image(image_path, target_size=(1536, 1536)):
616
  if orig_size is None:
617
  orig_size = (image_np.shape[1], image_np.shape[0])
618
  LOGGER.info(f"Original size: {orig_size}, focal: {f_px:.2f}px")
619
- tensor = torch.from_numpy(image_np).float() / 255.0
620
  tensor = tensor.permute(2, 0, 1)
621
  if (orig_size[0], orig_size[1]) != (target_size[1], target_size[0]):
622
  LOGGER.info(f"Resizing to {target_size[1]}x{target_size[0]}")
@@ -825,10 +481,9 @@ def main():
825
  parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
826
  parser.add_argument("--input-image", type=Path, default=None, action="append", help="Path to input image for validation")
827
  parser.add_argument("--no-external-data", action="store_true", help="Save model with inline data (no .onnx.data file needed)")
828
- parser.add_argument("--fp16", action="store_true", help="Quantize model to FP16 precision (half-precision)")
829
- parser.add_argument("--tolerance-mean", type=float, default=None, help="Custom mean angular tolerance in degrees")
830
- parser.add_argument("--tolerance-p99", type=float, default=None, help="Custom P99 angular tolerance in degrees")
831
- parser.add_argument("--tolerance-max", type=float, default=None, help="Custom max angular tolerance in degrees")
832
 
833
  args = parser.parse_args()
834
 
@@ -841,17 +496,10 @@ def main():
841
  input_shape = (args.height, args.width)
842
 
843
  LOGGER.info(f"Converting to ONNX: {args.output}")
844
- # Use inline data format for browser deployment (--no-external-data flag or default for web)
845
- use_external_data = not args.no_external_data
846
- convert_to_onnx(predictor, args.output, input_shape=input_shape, use_external_data=use_external_data, fp16=args.fp16)
847
  LOGGER.info(f"ONNX model saved to {args.output}")
848
 
849
- # Skip validation for FP16 models since they have inherent precision differences from FP32
850
- if args.validate and args.fp16:
851
- LOGGER.info("Validation skipped for FP16 model (precision differences expected)")
852
- LOGGER.info("Conversion complete!")
853
- return 0
854
-
855
  if args.validate:
856
  if args.input_image:
857
  for img_path in args.input_image:
@@ -878,6 +526,7 @@ def main():
878
  LOGGER.error("Validation failed!")
879
  return 1
880
 
 
881
  LOGGER.info("Conversion complete!")
882
  return 0
883
 
 
9
 
10
  import numpy as np
11
  import onnx
12
+ import onnx.external_data_helper as onnx_external_data
13
  import onnxoptimizer
14
  import onnxruntime as ort
15
  import torch
 
44
  self.random_tolerances = {
45
  "mean_vectors_3d_positions": 0.001,
46
  "singular_values_scales": 0.0001,
47
+ "quaternions_rotations": 2.0, # Increased for ONNX numerical precision
48
  "colors_rgb_linear": 0.002,
49
  "opacities_alpha_channel": 0.005,
50
  }
 
52
  self.image_tolerances = {
53
  "mean_vectors_3d_positions": 3.5,
54
  "singular_values_scales": 0.035,
55
+ "quaternions_rotations": 2.0, # Increased for ONNX numerical precision
56
  "colors_rgb_linear": 0.01,
57
  "opacities_alpha_channel": 0.05,
58
  }
59
  if self.angular_tolerances_random is None:
60
+ self.angular_tolerances_random = {"mean": 0.01, "p99": 0.1, "p99_9": 1.0, "max": 10.0}
61
  if self.angular_tolerances_image is None:
62
  self.angular_tolerances_image = {"mean": 0.2, "p99": 2.0, "p99_9": 5.0, "max": 25.0}
63
 
 
148
  try:
149
  if onnx_path.exists():
150
  onnx_path.unlink()
151
+ #LOGGER.info(f"Removed {onnx_path}")
152
  except Exception as e:
153
  LOGGER.warning(f"Could not remove {onnx_path}: {e}")
154
 
 
157
  try:
158
  if data_path.exists():
159
  data_path.unlink()
160
+ #LOGGER.info(f"Removed {data_path}")
161
  except Exception as e:
162
  LOGGER.warning(f"Could not remove {data_path}: {e}")
163
 
 
168
  for f in glob.glob(pattern):
169
  try:
170
  Path(f).unlink()
171
+ #LOGGER.info(f"Removed temporary file {f}")
172
  except Exception:
173
  pass
174
 
 
197
  return predictor
198
 
199
 
200
+ def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_external_data=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  LOGGER.info("Exporting to ONNX format...")
202
  predictor.depth_alignment.scale_map_estimator = None
203
  model = SharpModelTraceable(predictor)
 
217
 
218
  LOGGER.info(f"Exporting to ONNX: {output_path}")
219
 
 
 
220
  dynamic_axes = {}
221
  for name in OUTPUT_NAMES:
222
  if name == "opacities_alpha_channel":
 
223
  dynamic_axes[name] = {0: 'batch', 1: 'num_gaussians'}
224
  else:
 
225
  dynamic_axes[name] = {0: 'batch', 1: 'num_gaussians'}
226
 
227
  torch.onnx.export(
 
230
  input_names=['image', 'disparity_factor'],
231
  output_names=OUTPUT_NAMES,
232
  dynamic_axes=dynamic_axes,
233
+ opset_version=15,
234
+ external_data=True, # Save weights to external .onnx.data file for large models
235
  )
236
 
237
+ # Verify the external data file was created
238
+ data_path = output_path.with_suffix('.onnx.data')
239
+ if data_path.exists():
240
+ data_size_gb = data_path.stat().st_size / (1024**3)
241
+ LOGGER.info(f"External data file saved: {data_path} ({data_size_gb:.2f} GB)")
242
+ else:
243
+ LOGGER.warning("External data file not found - model may be inline or external data not created yet")
244
+ # Try to convert to external data format if not created automatically
245
+ try:
246
+ model_onnx = onnx.load(str(output_path))
247
+ onnx.external_data_helper.convert_model_to_external_data(model_onnx, all_tensors_to_one_file=True)
248
+ onnx.save(model_onnx, str(output_path))
249
+ if data_path.exists():
250
+ data_size_gb = data_path.stat().st_size / (1024**3)
251
+ LOGGER.info(f"External data file created: {data_path} ({data_size_gb:.2f} GB)")
252
+ except Exception as e:
253
+ LOGGER.warning(f"Could not create external data file: {e}")
254
+
255
+ LOGGER.info(f"ONNX model saved to {output_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  return output_path
257
 
258
 
 
272
  if orig_size is None:
273
  orig_size = (image_np.shape[1], image_np.shape[0])
274
  LOGGER.info(f"Original size: {orig_size}, focal: {f_px:.2f}px")
275
+ tensor = torch.from_numpy(image_np.copy()).float() / 255.0
276
  tensor = tensor.permute(2, 0, 1)
277
  if (orig_size[0], orig_size[1]) != (target_size[1], target_size[0]):
278
  LOGGER.info(f"Resizing to {target_size[1]}x{target_size[0]}")
 
481
  parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
482
  parser.add_argument("--input-image", type=Path, default=None, action="append", help="Path to input image for validation")
483
  parser.add_argument("--no-external-data", action="store_true", help="Save model with inline data (no .onnx.data file needed)")
484
+ parser.add_argument("--tolerance-mean", type=float, default=None, help="Custom mean angular tolerance for quaternion validation")
485
+ parser.add_argument("--tolerance-p99", type=float, default=None, help="Custom p99 angular tolerance for quaternion validation")
486
+ parser.add_argument("--tolerance-max", type=float, default=None, help="Custom max angular tolerance for quaternion validation")
 
487
 
488
  args = parser.parse_args()
489
 
 
496
  input_shape = (args.height, args.width)
497
 
498
  LOGGER.info(f"Converting to ONNX: {args.output}")
499
+ # Always use inline data for simplicity and compatibility
500
+ convert_to_onnx(predictor, args.output, input_shape=input_shape, use_external_data=False)
 
501
  LOGGER.info(f"ONNX model saved to {args.output}")
502
 
 
 
 
 
 
 
503
  if args.validate:
504
  if args.input_image:
505
  for img_path in args.input_image:
 
526
  LOGGER.error("Validation failed!")
527
  return 1
528
 
529
+ cleanup_extraneous_files()
530
  LOGGER.info("Conversion complete!")
531
  return 0
532
 
inference_onnx.py CHANGED
@@ -75,24 +75,10 @@ def run_inference(onnx_path: str | Path, image: np.ndarray, disparity_factor: fl
75
 
76
  LOGGER.info(f"Loading ONNX model: {onnx_path}")
77
 
78
- # Try execution providers in order of preference
79
- # CoreML is best for Apple Silicon (handles FP16 automatically)
80
- # CPU is fallback for models that CoreML doesn't support
81
-
82
- # Use all string providers with separate provider_options list
83
- providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
84
- provider_options = [{'AccelerateInference': True}, {}]
85
-
86
- try:
87
- session = ort.InferenceSession(str(onnx_path), providers=providers, provider_options=provider_options)
88
- LOGGER.info("Using CoreMLExecutionProvider for inference")
89
- except Exception as e:
90
- LOGGER.warning(f"CoreML execution failed, trying CPU: {e}")
91
- try:
92
- session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
93
- LOGGER.info("Using CPUExecutionProvider for inference")
94
- except Exception as cpu_e:
95
- raise RuntimeError(f"Failed to load ONNX model: {cpu_e}")
96
 
97
  input_names = [inp.name for inp in session.get_inputs()]
98
  output_names = [out.name for out in session.get_outputs()]
@@ -303,4 +289,3 @@ def main():
303
 
304
  if __name__ == "__main__":
305
  main()
306
-
 
75
 
76
  LOGGER.info(f"Loading ONNX model: {onnx_path}")
77
 
78
+ # Use CPUExecutionProvider for universal compatibility
79
+ # Works on all platforms and handles large models with external data files
80
+ session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
81
+ LOGGER.info("Using CPUExecutionProvider for inference")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  input_names = [inp.name for inp in session.get_inputs()]
84
  output_names = [out.name for out in session.get_outputs()]
 
289
 
290
  if __name__ == "__main__":
291
  main()