Kyle Pearson committed on
Commit
595d711
·
1 Parent(s): 6d257c6

Add quaternion validation, enable dynamic tolerance config, optimize ONNX export, fix race conditions in cleanup, add image-based validation, improve structured output reporting.

Browse files
Files changed (1) hide show
  1. convert_onnx.py +395 -534
convert_onnx.py CHANGED
@@ -1,13 +1,10 @@
1
- """Convert SHARP PyTorch model to ONNX format.
2
-
3
- This script converts the SHARP (Sharp Monocular View Synthesis) model
4
- from PyTorch (.pt) to ONNX (.onnx) format for deployment on various platforms.
5
- """
6
 
7
  from __future__ import annotations
8
 
9
  import argparse
10
  import logging
 
11
  from pathlib import Path
12
 
13
  import numpy as np
@@ -15,31 +12,105 @@ import onnx
15
  import onnxruntime as ort
16
  import torch
17
  import torch.nn as nn
 
18
 
19
- # Import SHARP model components
20
  from sharp.models import PredictorParams, create_predictor
21
  from sharp.models.predictor import RGBGaussianPredictor
 
22
 
23
  LOGGER = logging.getLogger(__name__)
24
-
25
  DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- class SharpModelTraceable(nn.Module):
29
- """Fully traceable version of SHARP for ONNX export.
30
-
31
- This version removes all dynamic control flow and makes the model
32
- fully traceable with torch.jit.trace.
33
- """
 
 
 
34
 
35
- def __init__(self, predictor: RGBGaussianPredictor):
36
- """Initialize the traceable wrapper.
37
 
38
- Args:
39
- predictor: The SHARP RGBGaussianPredictor model.
40
- """
41
  super().__init__()
42
- # Copy all submodules
43
  self.init_model = predictor.init_model
44
  self.feature_model = predictor.feature_model
45
  self.monodepth_model = predictor.monodepth_model
@@ -47,592 +118,382 @@ class SharpModelTraceable(nn.Module):
47
  self.gaussian_composer = predictor.gaussian_composer
48
  self.depth_alignment = predictor.depth_alignment
49
 
50
- def forward(
51
- self,
52
- image: torch.Tensor,
53
- disparity_factor: torch.Tensor
54
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
55
- """Run inference with traceable forward pass.
56
-
57
- Args:
58
- image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
59
- disparity_factor: Disparity factor tensor of shape (1,).
60
-
61
- Returns:
62
- Tuple of 5 tensors representing 3D Gaussians.
63
- """
64
- # Estimate depth using monodepth
65
- monodepth_output = self.monodepth_model(image)
66
- monodepth_disparity = monodepth_output.disparity
67
-
68
- # Convert disparity to depth with higher precision
69
- disparity_factor_expanded = disparity_factor[:, None, None, None]
70
-
71
- # Cast to float64 for more precise division, then back to float32
72
- disparity_clamped = monodepth_disparity.clamp(min=1e-6, max=1e4)
73
- monodepth = disparity_factor_expanded.double() / disparity_clamped.double()
74
- monodepth = monodepth.float()
75
-
76
- # Apply depth alignment (inference mode)
77
- monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)
78
-
79
- # Initialize gaussians
80
- init_output = self.init_model(image, monodepth)
81
-
82
- # Extract features
83
- image_features = self.feature_model(
84
- init_output.feature_input,
85
- encodings=monodepth_output.output_features
86
- )
87
-
88
- # Predict deltas
89
- delta_values = self.prediction_head(image_features)
90
-
91
- # Compose final gaussians
92
- gaussians = self.gaussian_composer(
93
- delta=delta_values,
94
- base_values=init_output.gaussian_base_values,
95
- global_scale=init_output.global_scale,
96
- )
97
-
98
- # Normalize quaternions for consistent validation and inference
99
- quaternions = gaussians.quaternions
100
-
101
- # Use double precision for quaternion normalization to reduce numerical errors
102
- quaternions_fp64 = quaternions.double()
103
- quat_norm_sq = torch.sum(quaternions_fp64 * quaternions_fp64, dim=-1, keepdim=True)
104
- quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-16))
105
- quaternions_normalized = quaternions_fp64 / quat_norm
106
-
107
- # Apply sign canonicalization for consistent representation
108
- # Find the component with the largest absolute value
109
- abs_quat = torch.abs(quaternions_normalized)
110
- max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)
111
-
112
- # Create one-hot selector for the max component
113
- one_hot = torch.zeros_like(quaternions_normalized)
114
  one_hot.scatter_(-1, max_idx, 1.0)
 
 
 
115
 
116
- # Get the sign of the max component
117
- max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)
118
-
119
- # Canonicalize: flip if max component is negative
120
- quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()
121
-
122
- return (
123
- gaussians.mean_vectors,
124
- gaussians.singular_values,
125
- quaternions,
126
- gaussians.colors,
127
- gaussians.opacities,
128
- )
129
 
130
-
131
- def cleanup_onnx_files(onnx_path: Path) -> None:
132
- """Remove ONNX file and any associated external data files.
133
-
134
- Args:
135
- onnx_path: Path to the ONNX file.
136
- """
137
  try:
138
  if onnx_path.exists():
139
- LOGGER.info(f"Removing existing ONNX file: {onnx_path}")
140
  onnx_path.unlink()
141
- except Exception as e:
142
- LOGGER.warning(f"Could not remove ONNX file {onnx_path}: {e}")
143
-
144
- # Also try to remove external data file
145
- external_data_path = onnx_path.with_suffix('.onnx.data')
146
  try:
147
- if external_data_path.exists():
148
- LOGGER.info(f"Removing existing external data file: {external_data_path}")
149
- external_data_path.unlink()
150
- except Exception as e:
151
- LOGGER.warning(f"Could not remove external data file {external_data_path}: {e}")
152
-
153
-
154
- def cleanup_extraneous_onnx_files() -> None:
155
- """Remove extraneous files created during ONNX conversion.
156
-
157
- This function removes intermediate files that PyTorch/ONNX creates
158
- during the export process but are not needed for the final model.
159
- """
160
- import glob
161
- import os
162
 
163
- # Patterns of extraneous files to remove
164
- patterns = [
165
- "onnx__*",
166
- "monodepth_*",
167
- "feature_model*",
168
- "_Constant_*",
169
- "_init_model_*"
170
- ]
171
 
172
- files_removed = 0
173
-
174
- for pattern in patterns:
175
- # Use glob to find files matching the pattern
176
- matching_files = glob.glob(pattern)
177
- for file_path in matching_files:
178
  try:
179
- os.remove(file_path)
180
- files_removed += 1
181
- LOGGER.debug(f"Removed extraneous file: {file_path}")
182
- except Exception as e:
183
- LOGGER.warning(f"Could not remove file {file_path}: {e}")
184
-
185
- if files_removed > 0:
186
- LOGGER.info(f"Cleaned up {files_removed} extraneous ONNX conversion files")
187
-
188
-
189
- def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
190
- """Load SHARP model from checkpoint.
191
 
192
- Args:
193
- checkpoint_path: Path to the .pt checkpoint file.
194
- If None, downloads the default model.
195
 
196
- Returns:
197
- The loaded RGBGaussianPredictor model in eval mode.
198
- """
199
  if checkpoint_path is None:
200
- LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
201
  state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
202
  else:
203
- LOGGER.info("Loading checkpoint from %s", checkpoint_path)
204
  state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
205
-
206
- # Create model with default parameters
207
  predictor = create_predictor(PredictorParams())
208
  predictor.load_state_dict(state_dict)
209
  predictor.eval()
210
-
211
  return predictor
212
 
213
 
214
- def convert_to_onnx(
215
- predictor: RGBGaussianPredictor,
216
- output_path: Path,
217
- input_shape: tuple[int, int] = (1536, 1536),
218
- ) -> Path:
219
- """Export SHARP model to ONNX format.
220
-
221
- Args:
222
- predictor: The SHARP RGBGaussianPredictor model.
223
- output_path: Path to save the .onnx file.
224
- input_shape: Input image shape (height, width).
225
-
226
- Returns:
227
- Path to the saved ONNX file.
228
- """
229
  LOGGER.info("Exporting to ONNX format...")
230
-
231
- # Ensure depth alignment is disabled for inference
232
  predictor.depth_alignment.scale_map_estimator = None
233
-
234
- # Create traceable wrapper
235
- model_wrapper = SharpModelTraceable(predictor)
236
- model_wrapper.eval()
237
-
238
- # Pre-warm the model
239
  LOGGER.info("Pre-warming model...")
240
  with torch.no_grad():
241
  for _ in range(3):
242
- warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
243
- warm_disparity = torch.tensor([1.0])
244
- _ = model_wrapper(warm_image, warm_disparity)
245
-
246
- # Clean up any existing ONNX files
247
  cleanup_onnx_files(output_path)
248
-
249
- # Create example inputs
250
- height, width = input_shape
251
  torch.manual_seed(42)
252
- example_image = torch.randn(1, 3, height, width)
253
- example_disparity_factor = torch.tensor([1.0])
254
-
255
- # Export to ONNX
256
  LOGGER.info(f"Exporting to ONNX: {output_path}")
257
-
 
 
 
 
 
 
 
 
258
  try:
259
- # Export with external data format to handle large models (>2GB)
260
- torch.onnx.export(
261
- model_wrapper,
262
- (example_image, example_disparity_factor),
263
- str(output_path),
264
- export_params=True,
265
- verbose=False,
266
- input_names=['image', 'disparity_factor'],
267
- output_names=[
268
- 'mean_vectors_3d_positions',
269
- 'singular_values_scales',
270
- 'quaternions_rotations',
271
- 'colors_rgb_linear',
272
- 'opacities_alpha_channel'
273
- ],
274
- dynamic_axes={
275
- 'mean_vectors_3d_positions': {1: 'num_gaussians'},
276
- 'singular_values_scales': {1: 'num_gaussians'},
277
- 'quaternions_rotations': {1: 'num_gaussians'},
278
- 'colors_rgb_linear': {1: 'num_gaussians'},
279
- 'opacities_alpha_channel': {1: 'num_gaussians'}
280
- },
281
- opset_version=17,
282
- )
283
-
284
- # For models >2GB, save with external data format
285
- try:
286
- model_proto = onnx.load(str(output_path))
287
- model_size = model_proto.ByteSize()
288
- if model_size > 2e9: # 2GB
289
- LOGGER.info(f"Model size {model_size/1e9:.2f}GB > 2GB, converting to external data format...")
290
- onnx.save_model(
291
- model_proto,
292
- str(output_path),
293
- save_as_external_data=True,
294
- all_tensors_to_one_file=True,
295
- location=f"{output_path.stem}.onnx.data",
296
- size_threshold=1024,
297
- convert_attribute=False,
298
- )
299
- LOGGER.info("Successfully saved with external data format")
300
- except Exception as e:
301
- LOGGER.warning(f"Could not check/convert to external data format: {e}")
302
-
303
- LOGGER.info("ONNX export successful")
304
  except Exception as e:
305
- LOGGER.error(f"ONNX export failed: {e}")
306
- raise
307
-
308
- # Verify ONNX model
309
  try:
310
  onnx.checker.check_model(str(output_path))
311
  LOGGER.info("ONNX model validation passed")
312
  except Exception as e:
313
  LOGGER.warning(f"ONNX model validation skipped: {e}")
314
-
315
- # Clean up extraneous files created during ONNX conversion
316
- cleanup_extraneous_onnx_files()
317
-
318
  return output_path
319
 
320
 
321
- def validate_onnx_model(
322
- onnx_path: Path,
323
- pytorch_model: RGBGaussianPredictor,
324
- input_shape: tuple[int, int] = (1536, 1536),
325
- tolerance: float = 0.01,
326
- ) -> bool:
327
- """Validate ONNX model outputs against PyTorch model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
- Args:
330
- onnx_path: Path to the ONNX model file.
331
- pytorch_model: The original PyTorch model.
332
- input_shape: Input image shape (height, width).
333
- tolerance: Maximum allowed difference between outputs.
334
 
335
- Returns:
336
- True if validation passes, False otherwise.
337
- """
338
  LOGGER.info("Validating ONNX model against PyTorch...")
339
-
340
- height, width = input_shape
341
-
342
- # Set seeds for reproducibility
343
  np.random.seed(42)
344
  torch.manual_seed(42)
345
-
346
- # Create test input
347
- test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
348
- test_disparity = np.array([1.0], dtype=np.float32)
349
-
350
- # Run PyTorch model
351
- test_image_pt = torch.from_numpy(test_image_np)
352
- test_disparity_pt = torch.from_numpy(test_disparity)
353
-
354
- traceable_wrapper = SharpModelTraceable(pytorch_model)
355
- traceable_wrapper.eval()
356
-
357
  with torch.no_grad():
358
- pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)
359
-
360
- # Run ONNX model
361
- try:
362
- session_options = ort.SessionOptions()
363
- session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
364
-
365
- providers = ['CPUExecutionProvider']
366
- session = ort.InferenceSession(str(onnx_path), session_options, providers=providers)
367
-
368
- onnx_inputs = {
369
- "image": test_image_np,
370
- "disparity_factor": test_disparity,
371
- }
372
-
373
- onnx_outputs = session.run(None, onnx_inputs)
374
-
375
- output_names = [
376
- 'mean_vectors_3d_positions',
377
- 'singular_values_scales',
378
- 'quaternions_rotations',
379
- 'colors_rgb_linear',
380
- 'opacities_alpha_channel'
381
- ]
382
-
383
- if len(onnx_outputs) != len(output_names):
384
- LOGGER.warning(f"ONNX outputs count mismatch: expected {len(output_names)}, got {len(onnx_outputs)}")
385
- onnx_output_dict = {f"output_{i}": output for i, output in enumerate(onnx_outputs)}
386
- else:
387
- onnx_output_dict = dict(zip(output_names, onnx_outputs))
388
-
389
- except Exception as e:
390
- LOGGER.error(f"Failed to run ONNX model: {e}")
391
- return False
392
-
393
- # Debug: Print shapes
394
- LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
395
- LOGGER.info(f"ONNX outputs shapes: {[v.shape for v in onnx_output_dict.values()]}")
396
-
397
- # Compare outputs with per-output tolerances
398
- output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
399
-
400
- tolerances = {
401
- "mean_vectors_3d_positions": 0.001,
402
- "singular_values_scales": 0.0001,
403
- "quaternions_rotations": 2.0,
404
- "colors_rgb_linear": 0.002,
405
- "opacities_alpha_channel": 0.005,
406
- }
407
-
408
- angular_tolerances = {
409
- "mean": 0.01,
410
- "p99": 0.5,
411
- "max": 10.0,
412
- }
413
-
414
  all_passed = True
415
-
416
- # Additional diagnostics for depth/position analysis
417
- LOGGER.info("=== Depth/Position Statistics ===")
418
- pt_positions = pt_outputs[0].numpy()
419
- onnx_positions = onnx_output_dict.get('mean_vectors_3d_positions', list(onnx_output_dict.values())[0])
420
-
421
- LOGGER.info(f"PyTorch positions - X range: [{pt_positions[..., 0].min():.4f}, {pt_positions[..., 0].max():.4f}], mean: {pt_positions[..., 0].mean():.4f}")
422
- LOGGER.info(f"PyTorch positions - Y range: [{pt_positions[..., 1].min():.4f}, {pt_positions[..., 1].max():.4f}], mean: {pt_positions[..., 1].mean():.4f}")
423
- LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
424
-
425
- LOGGER.info(f"ONNX positions - X range: [{onnx_positions[..., 0].min():.4f}, {onnx_positions[..., 0].max():.4f}], mean: {onnx_positions[..., 0].mean():.4f}")
426
- LOGGER.info(f"ONNX positions - Y range: [{onnx_positions[..., 1].min():.4f}, {onnx_positions[..., 1].max():.4f}], mean: {onnx_positions[..., 1].mean():.4f}")
427
- LOGGER.info(f"ONNX positions - Z range: [{onnx_positions[..., 2].min():.4f}, {onnx_positions[..., 2].max():.4f}], mean: {onnx_positions[..., 2].mean():.4f}, std: {onnx_positions[..., 2].std():.4f}")
428
-
429
- z_diff = np.abs(pt_positions[..., 2] - onnx_positions[..., 2])
430
- LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
431
- LOGGER.info("=================================")
432
-
433
- # Collect validation results for table output
434
- validation_results = []
435
-
436
- for i, name in enumerate(output_names):
437
- pt_output = pt_outputs[i].numpy()
438
-
439
- if name in onnx_output_dict:
440
- onnx_output = onnx_output_dict[name]
441
- else:
442
- if i < len(onnx_output_dict):
443
- onnx_output = list(onnx_output_dict.values())[i]
444
- else:
445
- LOGGER.warning(f"No ONNX output found for {name}")
446
- all_passed = False
447
- continue
448
-
449
  result = {"output": name, "passed": True, "failure_reason": ""}
450
-
451
- # Special handling for quaternions - account for sign ambiguity
452
  if name == "quaternions_rotations":
453
- # Normalize both quaternion outputs to ensure they're unit quaternions
454
- pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
455
- pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)
456
-
457
- onnx_quat_norm = np.linalg.norm(onnx_output, axis=-1, keepdims=True)
458
- onnx_output_normalized = onnx_output / np.clip(onnx_quat_norm, 1e-12, None)
459
-
460
- # Canonicalize sign: handle edge cases where w ≈ 0
461
- def canonicalize_quaternion(q):
462
- """Canonicalize quaternion to ensure unique representation."""
463
- abs_q = np.abs(q)
464
- max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
465
- selector = np.zeros_like(q)
466
- np.put_along_axis(selector, max_component_idx, 1, axis=-1)
467
- max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
468
- return np.where(max_component_sign < 0, -q, q)
469
-
470
- pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
471
- onnx_output_canonical = canonicalize_quaternion(onnx_output_normalized)
472
-
473
- # Compute differences with canonicalized quaternions
474
- diff = np.abs(pt_output_canonical - onnx_output_canonical)
475
- max_diff = np.max(diff)
476
- mean_diff = np.mean(diff)
477
-
478
- # Angular difference for rotations
479
- dot_products = np.sum(pt_output_canonical * onnx_output_canonical, axis=-1)
480
- dot_products = np.clip(np.abs(dot_products), 0.0, 1.0)
481
- angular_diff_rad = 2 * np.arccos(dot_products)
482
- angular_diff_deg = np.degrees(angular_diff_rad)
483
- max_angular = np.max(angular_diff_deg)
484
- mean_angular = np.mean(angular_diff_deg)
485
- p99_angular = np.percentile(angular_diff_deg, 99)
486
-
487
- quat_passed = True
488
- failure_reasons = []
489
-
490
- if mean_angular > angular_tolerances["mean"]:
491
- quat_passed = False
492
- failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
493
- if p99_angular > angular_tolerances["p99"]:
494
- quat_passed = False
495
- failure_reasons.append(f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°")
496
- if max_angular > angular_tolerances["max"]:
497
- quat_passed = False
498
- failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")
499
-
500
  result.update({
501
- "max_diff": f"{max_diff:.6f}",
502
- "mean_diff": f"{mean_diff:.6f}",
503
- "p99_diff": f"{np.percentile(diff, 99):.6f}",
504
- "max_angular": f"{max_angular:.4f}",
505
- "mean_angular": f"{mean_angular:.4f}",
506
- "p99_angular": f"{p99_angular:.4f}",
507
- "passed": quat_passed,
508
- "failure_reason": "; ".join(failure_reasons) if failure_reasons else ""
509
  })
510
-
511
- if not quat_passed:
512
  all_passed = False
513
  else:
514
- diff = np.abs(pt_output - onnx_output)
515
- max_diff = np.max(diff)
516
- mean_diff = np.mean(diff)
517
- p99_diff = np.percentile(diff, 99)
518
-
519
- output_tolerance = tolerances.get(name, tolerance)
520
-
521
  result.update({
522
- "max_diff": f"{max_diff:.6f}",
523
- "mean_diff": f"{mean_diff:.6f}",
524
- "p99_diff": f"{p99_diff:.6f}",
525
- "tolerance": f"{output_tolerance:.6f}"
526
  })
527
-
528
- if max_diff > output_tolerance:
529
  result["passed"] = False
530
- result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
531
  all_passed = False
532
-
533
- validation_results.append(result)
534
-
535
- # Output validation results as markdown table
536
- if validation_results:
537
- LOGGER.info("\n### Validation Results\n")
538
- LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
539
- LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")
540
-
541
- for result in validation_results:
542
- output_name = result["output"].replace("_", " ").title()
543
- max_diff = result["max_diff"]
544
- mean_diff = result["mean_diff"]
545
- p99_diff = result["p99_diff"]
546
-
547
- if "max_angular" in result:
548
- angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
549
- else:
550
- angular_info = "-"
551
-
552
- status = "✅ PASS" if result["passed"] else f"❌ FAIL"
553
- if result["failure_reason"]:
554
- status += f" ({result['failure_reason']})"
555
-
556
- LOGGER.info(f"| {output_name} | {max_diff} | {mean_diff} | {p99_diff} | {angular_info} | {status} |")
557
-
558
- LOGGER.info("")
559
-
560
  return all_passed
561
 
562
 
563
  def main():
564
- """Main conversion script."""
565
- parser = argparse.ArgumentParser(
566
- description="Convert SHARP PyTorch model to ONNX format"
567
- )
568
- parser.add_argument(
569
- "-c", "--checkpoint",
570
- type=Path,
571
- default=None,
572
- help="Path to PyTorch checkpoint. Downloads default if not provided.",
573
- )
574
- parser.add_argument(
575
- "-o", "--output",
576
- type=Path,
577
- default=Path("sharp.onnx"),
578
- help="Output path for ONNX model (default: sharp.onnx)",
579
- )
580
- parser.add_argument(
581
- "--height",
582
- type=int,
583
- default=1536,
584
- help="Input image height (default: 1536)",
585
- )
586
- parser.add_argument(
587
- "--width",
588
- type=int,
589
- default=1536,
590
- help="Input image width (default: 1536)",
591
- )
592
- parser.add_argument(
593
- "--validate",
594
- action="store_true",
595
- help="Validate ONNX model against PyTorch",
596
- )
597
- parser.add_argument(
598
- "-v", "--verbose",
599
- action="store_true",
600
- help="Enable verbose logging",
601
- )
602
-
603
  args = parser.parse_args()
604
-
605
- # Configure logging
606
- logging.basicConfig(
607
- level=logging.DEBUG if args.verbose else logging.INFO,
608
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
609
- )
610
-
611
- # Load PyTorch model
612
  LOGGER.info("Loading SHARP model...")
613
  predictor = load_sharp_model(args.checkpoint)
614
-
615
- # Setup conversion parameters
616
  input_shape = (args.height, args.width)
617
-
618
- # Convert to ONNX
619
  LOGGER.info(f"Converting to ONNX: {args.output}")
620
  convert_to_onnx(predictor, args.output, input_shape=input_shape)
621
  LOGGER.info(f"ONNX model saved to {args.output}")
622
-
623
- # Validate if requested
624
  if args.validate:
625
- if args.output.exists():
626
- validation_passed = validate_onnx_model(args.output, predictor, input_shape)
627
- if validation_passed:
628
- LOGGER.info(" Validation passed!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
  else:
630
- LOGGER.error("Validation failed!")
631
  return 1
632
- else:
633
- LOGGER.error(f"ONNX model not found at {args.output} for validation")
634
- return 1
635
-
636
  LOGGER.info("Conversion complete!")
637
  return 0
638
 
 
1
+ """Convert SHARP PyTorch model to ONNX format."""
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  import argparse
6
  import logging
7
+ from dataclasses import dataclass
8
  from pathlib import Path
9
 
10
  import numpy as np
 
12
  import onnxruntime as ort
13
  import torch
14
  import torch.nn as nn
15
+ import torch.nn.functional as F
16
 
 
17
  from sharp.models import PredictorParams, create_predictor
18
  from sharp.models.predictor import RGBGaussianPredictor
19
+ from sharp.utils import io
20
 
21
  LOGGER = logging.getLogger(__name__)
 
22
  DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
23
 
24
# Canonical names, in fixed order, of the five Gaussian output tensors the
# exported ONNX graph produces (3D positions, scales, rotation quaternions,
# RGB colors, opacities). Used for export output_names, dynamic axes, and
# matching ONNX Runtime outputs back to their PyTorch counterparts.
OUTPUT_NAMES = [
    "mean_vectors_3d_positions",
    "singular_values_scales",
    "quaternions_rotations",
    "colors_rgb_linear",
    "opacities_alpha_channel",
]
31
+
32
+
33
@dataclass
class ToleranceConfig:
    """Per-output validation tolerances for comparing ONNX and PyTorch results.

    Two tolerance sets exist: a tight one for random-noise test inputs and a
    looser one for real image inputs, plus angular tolerances (in degrees)
    for the quaternion output. Fields default to None and are filled with
    the standard values in __post_init__, which also keeps mutable dicts
    from being shared as class-level defaults.
    """

    # Max absolute per-output difference allowed for random test inputs.
    random_tolerances: dict | None = None
    # Max absolute per-output difference allowed for real image inputs.
    image_tolerances: dict | None = None
    # Angular (degree) thresholds for quaternions with random inputs.
    angular_tolerances_random: dict | None = None
    # Angular (degree) thresholds for quaternions with image inputs.
    angular_tolerances_image: dict | None = None

    def __post_init__(self):
        """Populate any tolerance set the caller left unset with defaults."""
        if self.random_tolerances is None:
            self.random_tolerances = {
                "mean_vectors_3d_positions": 0.001,
                "singular_values_scales": 0.0001,
                "quaternions_rotations": 2.0,
                "colors_rgb_linear": 0.002,
                "opacities_alpha_channel": 0.005,
            }
        if self.image_tolerances is None:
            self.image_tolerances = {
                "mean_vectors_3d_positions": 3.5,
                "singular_values_scales": 0.035,
                "quaternions_rotations": 5.0,
                "colors_rgb_linear": 0.01,
                "opacities_alpha_channel": 0.05,
            }
        if self.angular_tolerances_random is None:
            self.angular_tolerances_random = {"mean": 0.01, "p99": 0.1, "p99_9": 1.0, "max": 5.0}
        if self.angular_tolerances_image is None:
            self.angular_tolerances_image = {"mean": 0.2, "p99": 2.0, "p99_9": 5.0, "max": 25.0}
61
+
62
+
63
class QuaternionValidator:
    """Validates ONNX quaternion outputs against PyTorch reference outputs.

    Quaternions q and -q encode the same rotation, so both sides are unit
    normalized and sign-canonicalized before the angular error is measured.
    """

    def __init__(self, angular_tolerances=None, enable_outlier_analysis=True, outlier_thresholds=None):
        """Initialize the validator.

        Args:
            angular_tolerances: Mapping of statistic name ("mean", "p99",
                "p99_9", "max") to maximum allowed angular error in degrees.
            enable_outlier_analysis: Reserved flag for outlier reporting.
            outlier_thresholds: Degree thresholds for outlier analysis.
        """
        self.angular_tolerances = angular_tolerances or {"mean": 0.01, "p99": 0.5, "p99_9": 2.0, "max": 15.0}
        self.enable_outlier_analysis = enable_outlier_analysis
        self.outlier_thresholds = outlier_thresholds or [5.0, 10.0, 15.0]

    @staticmethod
    def canonicalize_quaternion(q):
        """Flip each quaternion so its largest-magnitude component is positive.

        Gives every rotation a unique sign representation so q and -q compare
        equal after canonicalization.
        """
        abs_q = np.abs(q)
        max_idx = np.argmax(abs_q, axis=-1, keepdims=True)
        selector = np.zeros_like(q)
        np.put_along_axis(selector, max_idx, 1.0, axis=-1)
        max_sign = np.sum(q * selector, axis=-1, keepdims=True)
        return np.where(max_sign < 0, -q, q)

    @staticmethod
    def compute_angular_differences(quats1, quats2):
        """Return per-quaternion angular error in degrees plus summary stats.

        Both inputs are normalized and canonicalized first; the rotation angle
        between matching quaternions is 2*arccos(|q1.q2|).
        """
        n1 = np.linalg.norm(quats1, axis=-1, keepdims=True)
        n2 = np.linalg.norm(quats2, axis=-1, keepdims=True)
        q1 = quats1 / np.clip(n1, 1e-12, None)
        q2 = quats2 / np.clip(n2, 1e-12, None)
        q1 = QuaternionValidator.canonicalize_quaternion(q1)
        q2 = QuaternionValidator.canonicalize_quaternion(q2)
        # |dot| is already invariant to a global sign flip of either operand,
        # so a separate flipped-dot pass would be redundant work.
        dots = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
        ang_deg = np.degrees(2.0 * np.arccos(dots))
        return ang_deg, {
            "mean": float(np.mean(ang_deg)),
            "std": float(np.std(ang_deg)),
            "max": float(np.max(ang_deg)),
            "p99": float(np.percentile(ang_deg, 99)),
            "p99_9": float(np.percentile(ang_deg, 99.9)),
        }

    def validate(self, pt_quats, onnx_quats, image_name="Unknown"):
        """Compare two quaternion sets against the configured tolerances.

        Returns:
            Dict with keys "image", "passed", "failure_reasons", "stats".
        """
        _, stats = self.compute_angular_differences(pt_quats, onnx_quats)
        failure_reasons = [
            f"{key} angular {stats[key]:.4f} > {limit:.4f}"
            for key, limit in self.angular_tolerances.items()
            if key in stats and stats[key] > limit
        ]
        return {
            "image": image_name,
            "passed": not failure_reasons,
            "failure_reasons": failure_reasons,
            "stats": stats,
        }
109
 
 
 
110
 
111
+ class SharpModelTraceable(nn.Module):
112
+ def __init__(self, predictor):
 
113
  super().__init__()
 
114
  self.init_model = predictor.init_model
115
  self.feature_model = predictor.feature_model
116
  self.monodepth_model = predictor.monodepth_model
 
118
  self.gaussian_composer = predictor.gaussian_composer
119
  self.depth_alignment = predictor.depth_alignment
120
 
121
+ def forward(self, image, disparity_factor):
122
+ monodepth_out = self.monodepth_model(image)
123
+ disp = monodepth_out.disparity
124
+ disp_factor = disparity_factor[:, None, None, None]
125
+ disp_clamped = disp.clamp(min=1e-4, max=1e4)
126
+ depth = disp_factor / disp_clamped
127
+ depth, _ = self.depth_alignment(depth, None, monodepth_out.decoder_features)
128
+ init_out = self.init_model(image, depth)
129
+ feats = self.feature_model(init_out.feature_input, encodings=monodepth_out.output_features)
130
+ deltas = self.prediction_head(feats)
131
+ gaussians = self.gaussian_composer(deltas, init_out.gaussian_base_values, init_out.global_scale)
132
+ quats = gaussians.quaternions
133
+ qnorm = torch.sqrt(torch.clamp(torch.sum(quats * quats, dim=-1, keepdim=True), min=1e-12))
134
+ quats = quats / qnorm
135
+ abs_q = torch.abs(quats)
136
+ max_idx = torch.argmax(abs_q, dim=-1, keepdim=True)
137
+ one_hot = torch.zeros_like(quats)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  one_hot.scatter_(-1, max_idx, 1.0)
139
+ max_sign = torch.sum(quats * one_hot, dim=-1, keepdim=True)
140
+ quats = torch.where(max_sign < 0, -quats, quats).float()
141
+ return (gaussians.mean_vectors, gaussians.singular_values, quats, gaussians.colors, gaussians.opacities)
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
def cleanup_onnx_files(onnx_path):
    """Remove an ONNX model file and its external-data sidecar, best-effort.

    Args:
        onnx_path: Path to the .onnx file; the sidecar path is derived by
            swapping the suffix for ``.onnx.data``, matching the location
            used when saving with external data.

    Deletion failures are deliberately swallowed so a stale, undeletable file
    never aborts a fresh export — but only filesystem errors are suppressed,
    not arbitrary exceptions.
    """
    from contextlib import suppress

    # missing_ok avoids the exists()/unlink() race where another process
    # removes the file between the check and the delete.
    with suppress(OSError):
        onnx_path.unlink(missing_ok=True)
    with suppress(OSError):
        onnx_path.with_suffix('.onnx.data').unlink(missing_ok=True)
 
 
 
 
 
 
 
 
 
 
 
156
 
 
 
 
 
 
 
 
 
157
 
158
def cleanup_extraneous_files():
    """Best-effort removal of intermediate files left in the current directory.

    torch.onnx.export can drop loose tensor/constant files (named after
    submodules or ONNX constants) into the working directory; delete anything
    matching those patterns. Failures are ignored per-file so one stubborn
    entry does not stop the sweep, but only filesystem errors are suppressed.
    """
    from contextlib import suppress
    from pathlib import Path

    patterns = ["onnx__*", "monodepth_*", "feature_model*", "_Constant_*", "_init_model_*"]
    for pattern in patterns:
        # Glob relative to the CWD, which is where the export writes them.
        for leftover in Path().glob(pattern):
            with suppress(OSError):
                leftover.unlink()
 
 
 
 
 
 
 
 
 
167
 
 
 
 
168
 
169
def load_sharp_model(checkpoint_path=None):
    """Build the SHARP predictor and load weights into it, in eval mode.

    Args:
        checkpoint_path: Local .pt checkpoint to load; when None, the default
            published checkpoint is fetched via torch.hub.

    Returns:
        The weight-loaded predictor, switched to eval mode.
    """
    if checkpoint_path is not None:
        LOGGER.info(f"Loading checkpoint from {checkpoint_path}")
        weights = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    else:
        LOGGER.info(f"Downloading model from {DEFAULT_MODEL_URL}")
        weights = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)

    model = create_predictor(PredictorParams())
    model.load_state_dict(weights)
    model.eval()
    return model
180
 
181
 
182
def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536)):
    """Trace the SHARP predictor and export it as an ONNX graph.

    Args:
        predictor: The SHARP RGBGaussianPredictor to export.
        output_path: Destination .onnx file path.
        input_shape: (height, width) of the fixed-size image input.

    Returns:
        The path the ONNX model was written to.
    """
    LOGGER.info("Exporting to ONNX format...")
    # Drop the scale-map estimator so depth alignment runs inference-only.
    predictor.depth_alignment.scale_map_estimator = None

    traceable = SharpModelTraceable(predictor)
    traceable.eval()

    height, width = input_shape

    # A few warm-up passes before tracing.
    LOGGER.info("Pre-warming model...")
    with torch.no_grad():
        for _ in range(3):
            _ = traceable(torch.randn(1, 3, height, width), torch.tensor([1.0]))

    # Remove any stale artifacts from a previous export.
    cleanup_onnx_files(output_path)

    torch.manual_seed(42)
    sample_image = torch.randn(1, 3, height, width)
    sample_disparity = torch.tensor([1.0])

    LOGGER.info(f"Exporting to ONNX: {output_path}")
    torch.onnx.export(
        traceable,
        (sample_image, sample_disparity),
        str(output_path),
        export_params=True,
        verbose=False,
        input_names=['image', 'disparity_factor'],
        output_names=OUTPUT_NAMES,
        dynamic_axes={name: {1: 'num_gaussians'} for name in OUTPUT_NAMES},
        opset_version=17,
    )

    # Protobuf caps a single file at 2GB; spill weights to a sidecar if needed.
    try:
        proto = onnx.load(str(output_path))
        if proto.ByteSize() > 2e9:
            LOGGER.info("Model > 2GB, converting to external data format...")
            onnx.save_model(proto, str(output_path), save_as_external_data=True,
                            all_tensors_to_one_file=True, location=f"{output_path.stem}.onnx.data")
    except Exception as e:
        LOGGER.warning(f"External data format check failed: {e}")

    # Structural check of the exported graph; non-fatal if it cannot run.
    try:
        onnx.checker.check_model(str(output_path))
        LOGGER.info("ONNX model validation passed")
    except Exception as e:
        LOGGER.warning(f"ONNX model validation skipped: {e}")

    cleanup_extraneous_files()
    return output_path
227
 
228
 
229
def find_onnx_output_key(name, onnx_outputs):
    """Resolve the key in *onnx_outputs* corresponding to output *name*.

    Resolution order:
      1. Exact key match.
      2. Case-insensitive substring match of the first underscore-separated
         component of *name* (e.g. "positions" for "positions_xyz").
      3. Positional fallback via OUTPUT_NAMES, clamped to the available keys.

    Args:
        name: Canonical output name (typically one of OUTPUT_NAMES).
        onnx_outputs: Mapping of ONNX output names to values.

    Returns:
        The matching key from *onnx_outputs*.
    """
    if name in onnx_outputs:
        return name
    # BUG FIX: lower-case the probe as well — previously only the key was
    # lowered, so a capitalized *name* could never match case-insensitively.
    prefix = name.split('_')[0].lower()
    for key in onnx_outputs:
        if prefix in key.lower():
            return key
    # BUG FIX: clamp the positional index so a dict with fewer entries than
    # OUTPUT_NAMES cannot raise IndexError for late output names.
    keys = list(onnx_outputs)
    index = OUTPUT_NAMES.index(name) if name in OUTPUT_NAMES else 0
    return keys[min(index, len(keys) - 1)]
236
+
237
+
238
def load_and_preprocess_image(image_path, target_size=(1536, 1536)):
    """Load an RGB image and prepare it as a (1, 3, H, W) float tensor.

    Pixels are scaled to [0, 1] and the image is bilinearly resized to
    *target_size* (height, width) when it does not already match.

    Returns:
        Tuple of (batched image tensor, focal length in pixels, original
        (width, height) size as reported by the loader).
    """
    LOGGER.info(f"Loading image from {image_path}")
    image_np, orig_size, f_px = io.load_rgb(image_path)
    if orig_size is None:
        # Derive (width, height) from the array when the loader gives no size.
        orig_size = (image_np.shape[1], image_np.shape[0])
    LOGGER.info(f"Original size: {orig_size}, focal: {f_px:.2f}px")

    th, tw = target_size
    image = torch.from_numpy(image_np).float() / 255.0
    image = image.permute(2, 0, 1)  # HWC -> CHW
    if (orig_size[0], orig_size[1]) != (tw, th):
        LOGGER.info(f"Resizing to {tw}x{th}")
        image = F.interpolate(image.unsqueeze(0), size=target_size, mode="bilinear", align_corners=True).squeeze(0)
    image = image.unsqueeze(0)  # add batch dimension
    LOGGER.info(f"Preprocessed shape: {image.shape}, range: [{image.min():.4f}, {image.max():.4f}]")
    return image, f_px, orig_size
253
+
254
+
255
def run_inference_pair(pytorch_model, onnx_path, image_tensor, disparity_factor=1.0, log_internals=False):
    """Run the same input through PyTorch and ONNX Runtime.

    Args:
        pytorch_model: SHARP predictor; wrapped in SharpModelTraceable so
            both backends see the same computation.
        onnx_path: Path to the exported .onnx file.
        image_tensor: (1, 3, H, W) image batch.
        disparity_factor: Scalar disparity factor fed to both backends.
        log_internals: Currently unused; kept for interface compatibility.

    Returns:
        Tuple of (PyTorch outputs as numpy arrays, ONNX outputs split into
        the same five per-Gaussian attribute groups).
    """
    wrapper = SharpModelTraceable(pytorch_model)
    wrapper.eval()
    image_tensor = image_tensor.float()
    disp_pt = torch.tensor([disparity_factor], dtype=torch.float32)
    with torch.no_grad():
        pt_outputs = wrapper(image_tensor, disp_pt)

    pt_np = [o.numpy() for o in pt_outputs]

    session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
    onnx_inputs = {"image": image_tensor.numpy(), "disparity_factor": np.array([disparity_factor], dtype=np.float32)}
    onnx_raw = session.run(None, onnx_inputs)

    LOGGER.info(f"ONNX raw outputs count: {len(onnx_raw)}, first shape: {onnx_raw[0].shape if len(onnx_raw) > 0 else 'N/A'}")

    if len(onnx_raw) == 1:
        # Single concatenated output: split the last axis into
        # positions(3) + scales(3) + quaternions(4) + colors(3) + opacities(1) = 14.
        total_size = onnx_raw[0].shape[-1]
        LOGGER.info(f"ONNX single output total size: {total_size}")
        sizes = [3, 3, 4, 3, 1]
        onnx_splits = []
        start = 0
        for size in sizes:  # fixed: dropped the unused enumerate index
            onnx_splits.append(onnx_raw[0][:, :, start:start + size])
            start += size
    else:
        # Already separated (the expected 5 outputs) or an unexpected layout
        # we pass through unchanged for the caller to diagnose.
        onnx_splits = list(onnx_raw)

    return pt_np, onnx_splits
291
+
292
+
293
def format_validation_table(results, image_name="", include_image=False):
    """Render validation results as a Markdown table.

    Args:
        results: List of dicts with keys "output", "passed", "max_diff",
            "mean_diff", and "p99_diff" (diff values pre-formatted strings).
        image_name: Image label used for the leading column when
            include_image is True.
        include_image: Whether to prepend an "Image" column to every row.

    Returns:
        The table as a single newline-joined string.
    """
    # Build the header once instead of duplicating the row loop per layout.
    if include_image:
        header = "| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|-------|--------|----------|-----------|----------|--------|"
    else:
        header = "| Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|--------|----------|-----------|----------|--------|"

    lines = [header, divider]
    for r in results:
        name = r["output"].replace("_", " ").title()
        status = "PASS" if r["passed"] else "FAIL"
        cells = [name, r["max_diff"], r["mean_diff"], r["p99_diff"], status]
        if include_image:
            cells.insert(0, image_name)
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)
310
+
311
+
312
def validate_with_image(onnx_path: Path, pytorch_model, image_path: Path, input_shape: tuple[int, int] = (1536, 1536)) -> bool:
    """Validate the exported ONNX model against PyTorch on a real image.

    Runs both backends on *image_path* (resized to *input_shape*), compares
    each output group against per-output tolerances, and logs a Markdown
    results table.

    Args:
        onnx_path: Path to the exported ONNX model.
        pytorch_model: The reference SHARP predictor.
        image_path: Path to the RGB validation image.
        input_shape: (height, width) the image is resized to.

    Returns:
        True when every output group passes its tolerance check.
    """
    LOGGER.info(f"Validating with image: {image_path}")
    # orig_size is (width, height) as returned by load_and_preprocess_image.
    test_image, f_px, (w, h) = load_and_preprocess_image(image_path, input_shape)
    # Normalized focal length; fed to the model as its disparity input.
    disparity_factor = f_px / w
    LOGGER.info(f"Using disparity_factor = {disparity_factor:.6f}")

    pt_outputs, onnx_out = run_inference_pair(pytorch_model, onnx_path, test_image, disparity_factor)

    LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
    LOGGER.info(f"ONNX output shapes: {[o.shape for o in onnx_out]}")

    # Image-specific (looser) tolerances; quaternions get their own validator
    # configured with angular tolerances.
    tolerance_config = ToleranceConfig()
    tolerances = tolerance_config.image_tolerances
    quat_validator = QuaternionValidator(angular_tolerances=tolerance_config.angular_tolerances_image)

    all_passed = True
    results = []

    # Outputs from both backends are positionally aligned with OUTPUT_NAMES.
    for i, name in enumerate(OUTPUT_NAMES):
        pt_out = pt_outputs[i]
        onnx_output = onnx_out[i]

        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Delegated to the quaternion-specific validator; it reports
            # max/mean/p99 stats (presumably angular, given the tolerance
            # names — see QuaternionValidator) and its own pass/fail verdict.
            quat_result = quat_validator.validate(pt_out, onnx_output, image_path.name)
            result.update({
                "max_diff": f"{quat_result['stats']['max']:.6f}",
                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
                "passed": quat_result["passed"],
                "failure_reason": "; ".join(quat_result["failure_reasons"]),
            })
            if not quat_result["passed"]:
                all_passed = False
        else:
            # Plain element-wise comparison: fail when the max absolute
            # difference exceeds the per-output tolerance (0.01 default).
            diff = np.abs(pt_out - onnx_output)
            tol = tolerances.get(name, 0.01)
            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
            })
            if np.max(diff) > tol:
                result["passed"] = False
                result["failure_reason"] = f"max diff {np.max(diff):.6f} > tol {tol:.6f}"
                all_passed = False

        results.append(result)

    LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
    LOGGER.info(format_validation_table(results, image_path.name, include_image=True))
    LOGGER.info("")

    return all_passed
367
 
 
 
 
 
 
368
 
369
def validate_onnx_model(onnx_path, pytorch_model, input_shape=(1536, 1536), angular_tolerances=None):
    """Validate the exported ONNX model against PyTorch on a random image.

    Args:
        onnx_path: Path to the exported ONNX model.
        pytorch_model: The reference SHARP predictor.
        input_shape: (height, width) of the random test image.
        angular_tolerances: Optional override for the quaternion validator's
            tolerances; falls back to ToleranceConfig's random defaults.

    Returns:
        True when every output group passes its tolerance check.
    """
    LOGGER.info("Validating ONNX model against PyTorch...")

    # Fixed seeds so the validation input (and any stochastic layers) are
    # reproducible run to run.
    np.random.seed(42)
    torch.manual_seed(42)

    test_image = np.random.rand(1, 3, input_shape[0], input_shape[1]).astype(np.float32)

    # Reuse the shared inference helper instead of duplicating the wrapper
    # construction, ONNX session setup, and output-splitting logic here
    # (the previous copy had drifted into a near-duplicate of it).
    pt_np, onnx_splits = run_inference_pair(
        pytorch_model, onnx_path, torch.from_numpy(test_image), disparity_factor=1.0
    )

    tolerance_config = ToleranceConfig()
    tolerances = tolerance_config.random_tolerances
    quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances or tolerance_config.angular_tolerances_random)

    all_passed = True
    results = []

    # Outputs from both backends are positionally aligned with OUTPUT_NAMES.
    for i, name in enumerate(OUTPUT_NAMES):
        pt_o = pt_np[i]
        onnx_o = onnx_splits[i]

        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Delegated to the quaternion-specific validator, which reports
            # max/mean/p99 stats and its own pass/fail verdict.
            qr = quat_validator.validate(pt_o, onnx_o, "Random")
            result.update({
                "max_diff": f"{qr['stats']['max']:.6f}",
                "mean_diff": f"{qr['stats']['mean']:.6f}",
                "p99_diff": f"{qr['stats']['p99']:.6f}",
                "passed": qr["passed"],
                "failure_reason": "; ".join(qr["failure_reasons"]),
            })
            if not qr["passed"]:
                all_passed = False
        else:
            # Plain element-wise comparison against the per-output tolerance.
            diff = np.abs(pt_o - onnx_o)
            tol = tolerances.get(name, 0.01)
            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
            })
            if np.max(diff) > tol:
                result["passed"] = False
                result["failure_reason"] = f"max diff {np.max(diff):.6f} > tol {tol:.6f}"
                all_passed = False

        results.append(result)

    LOGGER.info("\n### Random Validation Results\n")
    LOGGER.info(format_validation_table(results))
    LOGGER.info("")

    return all_passed
442
 
443
 
444
def main():
    """CLI entry point: convert SHARP to ONNX and optionally validate.

    Returns:
        Process exit code: 0 on success, 1 on validation failure or a
        missing validation image.
    """
    parser = argparse.ArgumentParser(description="Convert SHARP PyTorch model to ONNX format")
    parser.add_argument("-c", "--checkpoint", type=Path, default=None, help="Path to PyTorch checkpoint")
    parser.add_argument("-o", "--output", type=Path, default=Path("sharp.onnx"), help="Output path for ONNX model")
    parser.add_argument("--height", type=int, default=1536, help="Input image height")
    parser.add_argument("--width", type=int, default=1536, help="Input image width")
    parser.add_argument("--validate", action="store_true", help="Validate ONNX model against PyTorch")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
    parser.add_argument("--input-image", type=Path, default=None, action="append", help="Path to input image for validation")
    parser.add_argument("--tolerance-mean", type=float, default=None, help="Custom mean angular tolerance in degrees")
    parser.add_argument("--tolerance-p99", type=float, default=None, help="Custom P99 angular tolerance in degrees")
    parser.add_argument("--tolerance-max", type=float, default=None, help="Custom max angular tolerance in degrees")

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
                        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    input_shape = (args.height, args.width)

    LOGGER.info(f"Converting to ONNX: {args.output}")
    convert_to_onnx(predictor, args.output, input_shape=input_shape)
    LOGGER.info(f"ONNX model saved to {args.output}")

    if args.validate:
        if args.input_image:
            # Image-based validation: every supplied image must pass.
            for img_path in args.input_image:
                if not img_path.exists():
                    LOGGER.error(f"Image not found: {img_path}")
                    return 1
                passed = validate_with_image(args.output, predictor, img_path, input_shape)
                if not passed:
                    LOGGER.error(f"Validation failed for {img_path}")
                    return 1
        else:
            # BUG FIX: compare against None so an explicit 0.0 tolerance is
            # honored — a plain truthiness test treated 0.0 as "not provided".
            angular_tolerances = None
            custom = (args.tolerance_mean, args.tolerance_p99, args.tolerance_max)
            if any(t is not None for t in custom):
                angular_tolerances = {
                    "mean": args.tolerance_mean if args.tolerance_mean is not None else 0.01,
                    "p99": args.tolerance_p99 if args.tolerance_p99 is not None else 0.5,
                    "p99_9": 2.0,
                    "max": args.tolerance_max if args.tolerance_max is not None else 15.0,
                }
            passed = validate_onnx_model(args.output, predictor, input_shape, angular_tolerances=angular_tolerances)
            if passed:
                LOGGER.info("Validation passed!")
            else:
                LOGGER.error("Validation failed!")
                return 1

    LOGGER.info("Conversion complete!")
    return 0
499