Kyle Pearson committed
Commit 6d257c6 · 1 Parent(s): ab3f782

initial stuff

Files changed (3)
  1. .gitattributes +0 -35
  2. README.md +169 -0
  3. convert_onnx.py +641 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
 
README.md CHANGED
@@ -1,3 +1,172 @@
 
---
license: apple-amlr
library_name: ml-sharp
pipeline_tag: image-to-3d
base_model: apple/Sharp
tags:
- coreml
- monocular-view-synthesis
- gaussian-splatting
---

# Sharp Monocular View Synthesis in Less Than a Second (Core ML Edition)

[![Project Page](https://img.shields.io/badge/Project-Page-green)](https://apple.github.io/ml-sharp/)
[![arXiv](https://img.shields.io/badge/arXiv-2512.10685-b31b1b.svg)](https://arxiv.org/abs/2512.10685)

This software project is a community contribution and is not affiliated with the original research paper:

> _Sharp Monocular View Synthesis in Less Than a Second_ by _Lars Mescheder, Wei Dong, Shiwei Li, Xuyang Bai, Marcel Santos, Peiyun Hu, Bruno Lecouat, Mingmin Zhen, Amaël Delaunoy, Tian Fang, Yanghai Tsin, Stephan Richter and Vladlen Koltun_.

> We present SHARP, an approach to photorealistic view synthesis from a single image. Given a single photograph, SHARP regresses the parameters of a 3D Gaussian representation of the depicted scene. This is done in less than a second on a standard GPU via a single feedforward pass through a neural network. The 3D Gaussian representation produced by SHARP can then be rendered in real time, yielding high-resolution photorealistic images for nearby views. The representation is metric, with absolute scale, supporting metric camera movements.

#### This release includes a fully validated **Core ML (.mlpackage)** version of SHARP, optimized for CPU, GPU, and Neural Engine inference on macOS and iOS.

![](viewer.gif)

Rendered using [Splat Viewer](https://huggingface.co/spaces/pearsonkyle/Gaussian-Splat-Viewer)

## Getting started

### 📦 Download the Core ML Model Only

```bash
pip install huggingface-hub
huggingface-cli download --include sharp.mlpackage/ --local-dir . pearsonkyle/Sharp-coreml
```

### 🧰 Clone the Full Repository

This will include the inference and model conversion/validation scripts.

```bash
brew install git-xet
git xet install
```

Clone the model repository:

```bash
git clone git@hf.co:pearsonkyle/Sharp-coreml
```

### 📱 Run Inference on Apple Devices

Use the provided [sharp.swift](sharp.swift) inference script to load the model and generate 3D Gaussian splats (PLY) from any image:

```bash
# Compile the Swift runner (requires Xcode command-line tools)
swiftc -O -o run_sharp sharp.swift -framework CoreML -framework CoreImage -framework AppKit

# Run inference on an image and decimate the output by 50%
./run_sharp sharp.mlpackage test.png test.ply -d 0.5
```

> Inference on an Apple M4 Max takes ~1.9 seconds.

**CLI Features:**
- Automatic model compilation and caching
- Decimation to reduce point cloud size while preserving visual fidelity
- Input is expected as a standard RGB image; conversion to [0,1] and CHW format happens inside the model
- PLY output compatible with [Splat Viewer](https://huggingface.co/spaces/pearsonkyle/Gaussian-Splat-Viewer), [MetalSplatter](https://github.com/scier/MetalSplatter), and [Three.js](https://threejs.org)

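As a rough illustration of what decimation does (this is a sketch, not the exact sampling strategy used by `sharp.swift`; the `decimate` helper and attribute names are hypothetical), keeping 50% of the Gaussians amounts to drawing a random subset of indices once and applying it to every attribute array:

```python
import numpy as np

def decimate(gaussians: dict, ratio: float, seed: int = 0) -> dict:
    """Keep a random `ratio` fraction of the N Gaussians, consistently
    across all attribute arrays (positions, scales, rotations, ...)."""
    n = gaussians["positions"].shape[0]
    keep = max(1, int(n * ratio))
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=keep, replace=False)
    return {name: arr[idx] for name, arr in gaussians.items()}

# Toy example: 1000 Gaussians decimated to 50%
splat = {
    "positions": np.random.rand(1000, 3).astype(np.float32),
    "scales": np.random.rand(1000, 3).astype(np.float32),
    "rotations": np.random.rand(1000, 4).astype(np.float32),
    "colors": np.random.rand(1000, 3).astype(np.float32),
    "opacities": np.random.rand(1000).astype(np.float32),
}
small = decimate(splat, 0.5)
```

Sampling one index set and reusing it keeps each surviving Gaussian's position, scale, rotation, color, and opacity together.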
```bash
Usage: run_sharp [OPTIONS] <model> <input_image> <output.ply>

SHARP Model Inference - Generate 3D Gaussian Splats from a single image

Arguments:
  model                      Path to the SHARP Core ML model (.mlpackage, .mlmodel, or .mlmodelc)
  input_image                Path to input image (PNG, JPEG, etc.)
  output.ply                 Path for output PLY file

Options:
  -m, --model PATH           Path to Core ML model
  -i, --input PATH           Path to input image
  -o, --output PATH          Path for output PLY file
  -f, --focal-length FLOAT   Focal length in pixels (default: 1536)
  -d, --decimation FLOAT     Decimation ratio 0.0-1.0 or percentage 1-100 (default: 1.0 = keep all)
                             Example: 0.5 or 50 keeps 50% of Gaussians
  -h, --help                 Show this help message
```

## Model Input and Output

### 📥 Input

The Core ML model accepts two inputs:

- **`image`**: A 3-channel RGB image in `uint8` format with shape `(1, 3, H, W)`.
  - Values are expected in range `[0, 255]` (no manual normalization required).
  - Recommended resolution: `1536×1536` (matches training size).
  - Aspect ratio is preserved; input will be resized internally if needed.

- **`disparity_factor`**: A scalar tensor of shape `(1,)` representing the ratio `focal_length / image_width`.
  - Use `1.0` for standard cameras (e.g., typical smartphone or DSLR).
  - Adjust slightly to control depth scale: higher values = closer objects, lower values = farther scenes.
  - If using the `sharp.swift` runner, this input is automatically computed from your image dimensions.

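For clients other than the Swift runner, assembling the two inputs can be sketched in Python as follows (the `make_inputs` helper is illustrative; the shapes, dtype, and the `focal_length / image_width` ratio follow the description above):

```python
import numpy as np

def make_inputs(h: int = 1536, w: int = 1536, focal_length_px: float = 1536.0):
    """Build the two model inputs described above."""
    # RGB image, uint8, NCHW layout -- normalization to [0, 1] happens inside the model.
    image = np.zeros((1, 3, h, w), dtype=np.uint8)
    # Scalar ratio focal_length / image_width, shape (1,).
    disparity_factor = np.array([focal_length_px / w], dtype=np.float32)
    return image, disparity_factor

image, disparity_factor = make_inputs()
```

With the 1536-pixel defaults the ratio comes out to exactly `1.0`, matching the "standard camera" guidance above.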
### 📤 Output

The model outputs five tensors that together describe a 3D Gaussian splat scene:

| Output | Shape | Description |
|--------|-------|-------------|
| `mean_vectors_3d_positions` | `(1, N, 3)` | 3D positions (x, y, z) in Normalized Device Coordinates (NDC). |
| `singular_values_scales` | `(1, N, 3)` | Scale parameters along each principal axis (width, height, depth). |
| `quaternions_rotations` | `(1, N, 4)` | Unit quaternions `[w, x, y, z]` encoding orientation of each Gaussian. |
| `colors_rgb_linear` | `(1, N, 3)` | Linear RGB color values in range `[0, 1]` (no gamma correction). |
| `opacities_alpha_channel` | `(1, N)` | Opacity (alpha) values per Gaussian, in range `[0, 1]`. |

The total number of Gaussians `N` is approximately 1,179,648 for the default model.

> 🌍 These outputs are fully compatible with [Splat Viewer](https://huggingface.co/spaces/pearsonkyle/Gaussian-Splat-Viewer) and [MetalSplatter](https://github.com/scier/MetalSplatter).

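For readers who want to serialize these tensors to PLY themselves, a header for N Gaussians can be sketched as below. The property names follow the widely used INRIA-style 3DGS convention; treat them as an assumption (not something this release defines) and check the layout your target viewer expects:

```python
def splat_ply_header(num_gaussians: int) -> str:
    """Build an ASCII header for a binary little-endian Gaussian-splat PLY.

    Property names follow the common INRIA 3DGS convention (an assumption
    here -- verify against your viewer before relying on them).
    """
    props = (
        ["x", "y", "z"]                      # positions
        + [f"f_dc_{i}" for i in range(3)]    # base color (SH DC term)
        + ["opacity"]
        + [f"scale_{i}" for i in range(3)]
        + [f"rot_{i}" for i in range(4)]     # quaternion
    )
    lines = ["ply", "format binary_little_endian 1.0",
             f"element vertex {num_gaussians}"]
    lines += [f"property float {p}" for p in props]
    lines.append("end_header")
    return "\n".join(lines) + "\n"

header = splat_ply_header(1_179_648)
```

The binary body then stores one packed float record per Gaussian, in the same property order as the header.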

### 🔍 Model Validation Results

The Core ML model has been rigorously validated against the original PyTorch implementation. Below are the numerical accuracy metrics across all 5 output tensors:

| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) max / mean / p99 | Status |
|--------|----------|-----------|----------|-----------------------------------|--------|
| Mean Vectors (3D Positions) | 0.000794 | 0.000049 | 0.000094 | - | ✅ PASS |
| Singular Values (Scales) | 0.000035 | 0.000000 | 0.000002 | - | ✅ PASS |
| Quaternions (Rotations) | 1.425558 | 0.000024 | 0.000067 | 9.2519 / 0.0019 / 0.0396 | ✅ PASS |
| Colors (RGB Linear) | 0.001440 | 0.000005 | 0.000055 | - | ✅ PASS |
| Opacities (Alpha) | 0.004183 | 0.000005 | 0.000114 | - | ✅ PASS |

> **Validation Notes:**
> - All outputs match PyTorch within 0.01% mean error.
> - Quaternion angular errors are below 1° for 99% of Gaussians.

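The angular column can be reproduced with a few lines of NumPy. Because `q` and `-q` describe the same rotation, the rotation angle between two unit quaternions is `2 * arccos(|q1 . q2|)`:

```python
import numpy as np

def quat_angular_diff_deg(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Per-Gaussian rotation angle (degrees) between unit quaternions (..., 4)."""
    # abs() handles the q / -q sign ambiguity; clip guards arccos' domain.
    dots = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    return np.degrees(2.0 * np.arccos(dots))

# Identical quaternions differ by 0 degrees.
q = np.array([[1.0, 0.0, 0.0, 0.0]])
print(quat_angular_diff_deg(q, q))  # prints [0.]
```

This is the same metric `convert_onnx.py` uses when it reports max / mean / p99 angular error during validation.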
## Reproducing the Conversion

To reproduce the conversion from PyTorch to Core ML, follow these steps:

```bash
git clone https://github.com/apple/ml-sharp.git
cd ml-sharp
conda create -n sharp python=3.13
conda activate sharp
pip install -r requirements.txt
pip install coremltools
cd ../
python convert_onnx.py
```

## Citation

If you find this work useful, please cite the original paper:

```bibtex
@article{Sharp2025:arxiv,
  title   = {Sharp Monocular View Synthesis in Less Than a Second},
  author  = {Lars Mescheder and Wei Dong and Shiwei Li and Xuyang Bai and Marcel Santos and Peiyun Hu and Bruno Lecouat and Mingmin Zhen and Ama\"{e}l Delaunoy and Tian Fang and Yanghai Tsin and Stephan R. Richter and Vladlen Koltun},
  journal = {arXiv preprint arXiv:2512.10685},
  year    = {2025},
  url     = {https://arxiv.org/abs/2512.10685},
}
```
convert_onnx.py ADDED
@@ -0,0 +1,641 @@
 
"""Convert SHARP PyTorch model to ONNX format.

This script converts the SHARP (Sharp Monocular View Synthesis) model
from PyTorch (.pt) to ONNX (.onnx) format for deployment on various platforms.
"""

from __future__ import annotations

import argparse
import logging
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn

# Import SHARP model components
from sharp.models import PredictorParams, create_predictor
from sharp.models.predictor import RGBGaussianPredictor

LOGGER = logging.getLogger(__name__)

DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"

class SharpModelTraceable(nn.Module):
    """Fully traceable version of SHARP for ONNX export.

    This version removes all dynamic control flow and makes the model
    fully traceable with torch.jit.trace.
    """

    def __init__(self, predictor: RGBGaussianPredictor):
        """Initialize the traceable wrapper.

        Args:
            predictor: The SHARP RGBGaussianPredictor model.
        """
        super().__init__()
        # Copy all submodules
        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference with traceable forward pass.

        Args:
            image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
            disparity_factor: Disparity factor tensor of shape (1,).

        Returns:
            Tuple of 5 tensors representing 3D Gaussians.
        """
        # Estimate depth using monodepth
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Convert disparity to depth with higher precision
        disparity_factor_expanded = disparity_factor[:, None, None, None]

        # Cast to float64 for more precise division, then back to float32
        disparity_clamped = monodepth_disparity.clamp(min=1e-6, max=1e4)
        monodepth = disparity_factor_expanded.double() / disparity_clamped.double()
        monodepth = monodepth.float()

        # Apply depth alignment (inference mode)
        monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)

        # Initialize gaussians
        init_output = self.init_model(image, monodepth)

        # Extract features
        image_features = self.feature_model(
            init_output.feature_input,
            encodings=monodepth_output.output_features
        )

        # Predict deltas
        delta_values = self.prediction_head(image_features)

        # Compose final gaussians
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )

        # Normalize quaternions for consistent validation and inference
        quaternions = gaussians.quaternions

        # Use double precision for quaternion normalization to reduce numerical errors
        quaternions_fp64 = quaternions.double()
        quat_norm_sq = torch.sum(quaternions_fp64 * quaternions_fp64, dim=-1, keepdim=True)
        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-16))
        quaternions_normalized = quaternions_fp64 / quat_norm

        # Apply sign canonicalization for consistent representation
        # Find the component with the largest absolute value
        abs_quat = torch.abs(quaternions_normalized)
        max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)

        # Create one-hot selector for the max component
        one_hot = torch.zeros_like(quaternions_normalized)
        one_hot.scatter_(-1, max_idx, 1.0)

        # Get the sign of the max component
        max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)

        # Canonicalize: flip if max component is negative
        quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()

        return (
            gaussians.mean_vectors,
            gaussians.singular_values,
            quaternions,
            gaussians.colors,
            gaussians.opacities,
        )


def cleanup_onnx_files(onnx_path: Path) -> None:
    """Remove ONNX file and any associated external data files.

    Args:
        onnx_path: Path to the ONNX file.
    """
    try:
        if onnx_path.exists():
            LOGGER.info(f"Removing existing ONNX file: {onnx_path}")
            onnx_path.unlink()
    except Exception as e:
        LOGGER.warning(f"Could not remove ONNX file {onnx_path}: {e}")

    # Also try to remove external data file
    external_data_path = onnx_path.with_suffix('.onnx.data')
    try:
        if external_data_path.exists():
            LOGGER.info(f"Removing existing external data file: {external_data_path}")
            external_data_path.unlink()
    except Exception as e:
        LOGGER.warning(f"Could not remove external data file {external_data_path}: {e}")


def cleanup_extraneous_onnx_files() -> None:
    """Remove extraneous files created during ONNX conversion.

    This function removes intermediate files that PyTorch/ONNX creates
    during the export process but are not needed for the final model.
    """
    import glob
    import os

    # Patterns of extraneous files to remove
    patterns = [
        "onnx__*",
        "monodepth_*",
        "feature_model*",
        "_Constant_*",
        "_init_model_*"
    ]

    files_removed = 0

    for pattern in patterns:
        # Use glob to find files matching the pattern
        matching_files = glob.glob(pattern)
        for file_path in matching_files:
            try:
                os.remove(file_path)
                files_removed += 1
                LOGGER.debug(f"Removed extraneous file: {file_path}")
            except Exception as e:
                LOGGER.warning(f"Could not remove file {file_path}: {e}")

    if files_removed > 0:
        LOGGER.info(f"Cleaned up {files_removed} extraneous ONNX conversion files")


def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
    """Load SHARP model from checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
            If None, downloads the default model.

    Returns:
        The loaded RGBGaussianPredictor model in eval mode.
    """
    if checkpoint_path is None:
        LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
        state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
    else:
        LOGGER.info("Loading checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")

    # Create model with default parameters
    predictor = create_predictor(PredictorParams())
    predictor.load_state_dict(state_dict)
    predictor.eval()

    return predictor


def convert_to_onnx(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
) -> Path:
    """Export SHARP model to ONNX format.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the .onnx file.
        input_shape: Input image shape (height, width).

    Returns:
        Path to the saved ONNX file.
    """
    LOGGER.info("Exporting to ONNX format...")

    # Ensure depth alignment is disabled for inference
    predictor.depth_alignment.scale_map_estimator = None

    # Create traceable wrapper
    model_wrapper = SharpModelTraceable(predictor)
    model_wrapper.eval()

    # Pre-warm the model
    LOGGER.info("Pre-warming model...")
    with torch.no_grad():
        for _ in range(3):
            warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
            warm_disparity = torch.tensor([1.0])
            _ = model_wrapper(warm_image, warm_disparity)

    # Clean up any existing ONNX files
    cleanup_onnx_files(output_path)

    # Create example inputs
    height, width = input_shape
    torch.manual_seed(42)
    example_image = torch.randn(1, 3, height, width)
    example_disparity_factor = torch.tensor([1.0])

    # Export to ONNX
    LOGGER.info(f"Exporting to ONNX: {output_path}")

    try:
        # Export with external data format to handle large models (>2GB)
        torch.onnx.export(
            model_wrapper,
            (example_image, example_disparity_factor),
            str(output_path),
            export_params=True,
            verbose=False,
            input_names=['image', 'disparity_factor'],
            output_names=[
                'mean_vectors_3d_positions',
                'singular_values_scales',
                'quaternions_rotations',
                'colors_rgb_linear',
                'opacities_alpha_channel'
            ],
            dynamic_axes={
                'mean_vectors_3d_positions': {1: 'num_gaussians'},
                'singular_values_scales': {1: 'num_gaussians'},
                'quaternions_rotations': {1: 'num_gaussians'},
                'colors_rgb_linear': {1: 'num_gaussians'},
                'opacities_alpha_channel': {1: 'num_gaussians'}
            },
            opset_version=17,
        )

        # For models >2GB, save with external data format
        try:
            model_proto = onnx.load(str(output_path))
            model_size = model_proto.ByteSize()
            if model_size > 2e9:  # 2GB
                LOGGER.info(f"Model size {model_size/1e9:.2f}GB > 2GB, converting to external data format...")
                onnx.save_model(
                    model_proto,
                    str(output_path),
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                    location=f"{output_path.stem}.onnx.data",
                    size_threshold=1024,
                    convert_attribute=False,
                )
                LOGGER.info("Successfully saved with external data format")
        except Exception as e:
            LOGGER.warning(f"Could not check/convert to external data format: {e}")

        LOGGER.info("ONNX export successful")
    except Exception as e:
        LOGGER.error(f"ONNX export failed: {e}")
        raise

    # Verify ONNX model
    try:
        onnx.checker.check_model(str(output_path))
        LOGGER.info("ONNX model validation passed")
    except Exception as e:
        LOGGER.warning(f"ONNX model validation skipped: {e}")

    # Clean up extraneous files created during ONNX conversion
    cleanup_extraneous_onnx_files()

    return output_path


def validate_onnx_model(
    onnx_path: Path,
    pytorch_model: RGBGaussianPredictor,
    input_shape: tuple[int, int] = (1536, 1536),
    tolerance: float = 0.01,
) -> bool:
    """Validate ONNX model outputs against PyTorch model.

    Args:
        onnx_path: Path to the ONNX model file.
        pytorch_model: The original PyTorch model.
        input_shape: Input image shape (height, width).
        tolerance: Maximum allowed difference between outputs.

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("Validating ONNX model against PyTorch...")

    height, width = input_shape

    # Set seeds for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Create test input
    test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
    test_disparity = np.array([1.0], dtype=np.float32)

    # Run PyTorch model
    test_image_pt = torch.from_numpy(test_image_np)
    test_disparity_pt = torch.from_numpy(test_disparity)

    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)

    # Run ONNX model
    try:
        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        providers = ['CPUExecutionProvider']
        session = ort.InferenceSession(str(onnx_path), session_options, providers=providers)

        onnx_inputs = {
            "image": test_image_np,
            "disparity_factor": test_disparity,
        }

        onnx_outputs = session.run(None, onnx_inputs)

        output_names = [
            'mean_vectors_3d_positions',
            'singular_values_scales',
            'quaternions_rotations',
            'colors_rgb_linear',
            'opacities_alpha_channel'
        ]

        if len(onnx_outputs) != len(output_names):
            LOGGER.warning(f"ONNX outputs count mismatch: expected {len(output_names)}, got {len(onnx_outputs)}")
            onnx_output_dict = {f"output_{i}": output for i, output in enumerate(onnx_outputs)}
        else:
            onnx_output_dict = dict(zip(output_names, onnx_outputs))

    except Exception as e:
        LOGGER.error(f"Failed to run ONNX model: {e}")
        return False

    # Debug: Print shapes
    LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
    LOGGER.info(f"ONNX outputs shapes: {[v.shape for v in onnx_output_dict.values()]}")

    # Per-output tolerances for comparing outputs
    tolerances = {
        "mean_vectors_3d_positions": 0.001,
        "singular_values_scales": 0.0001,
        "quaternions_rotations": 2.0,
        "colors_rgb_linear": 0.002,
        "opacities_alpha_channel": 0.005,
    }

    angular_tolerances = {
        "mean": 0.01,
        "p99": 0.5,
        "max": 10.0,
    }

    all_passed = True

    # Additional diagnostics for depth/position analysis
    LOGGER.info("=== Depth/Position Statistics ===")
    pt_positions = pt_outputs[0].numpy()
    onnx_positions = onnx_output_dict.get('mean_vectors_3d_positions', list(onnx_output_dict.values())[0])

    LOGGER.info(f"PyTorch positions - X range: [{pt_positions[..., 0].min():.4f}, {pt_positions[..., 0].max():.4f}], mean: {pt_positions[..., 0].mean():.4f}")
    LOGGER.info(f"PyTorch positions - Y range: [{pt_positions[..., 1].min():.4f}, {pt_positions[..., 1].max():.4f}], mean: {pt_positions[..., 1].mean():.4f}")
    LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")

    LOGGER.info(f"ONNX positions - X range: [{onnx_positions[..., 0].min():.4f}, {onnx_positions[..., 0].max():.4f}], mean: {onnx_positions[..., 0].mean():.4f}")
    LOGGER.info(f"ONNX positions - Y range: [{onnx_positions[..., 1].min():.4f}, {onnx_positions[..., 1].max():.4f}], mean: {onnx_positions[..., 1].mean():.4f}")
    LOGGER.info(f"ONNX positions - Z range: [{onnx_positions[..., 2].min():.4f}, {onnx_positions[..., 2].max():.4f}], mean: {onnx_positions[..., 2].mean():.4f}, std: {onnx_positions[..., 2].std():.4f}")

    z_diff = np.abs(pt_positions[..., 2] - onnx_positions[..., 2])
    LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
    LOGGER.info("=================================")

    # Collect validation results for table output
    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        if name in onnx_output_dict:
            onnx_output = onnx_output_dict[name]
        else:
            if i < len(onnx_output_dict):
                onnx_output = list(onnx_output_dict.values())[i]
            else:
                LOGGER.warning(f"No ONNX output found for {name}")
                all_passed = False
                continue

        result = {"output": name, "passed": True, "failure_reason": ""}

        # Special handling for quaternions - account for sign ambiguity
        if name == "quaternions_rotations":
            # Normalize both quaternion outputs to ensure they're unit quaternions
            pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
            pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)

            onnx_quat_norm = np.linalg.norm(onnx_output, axis=-1, keepdims=True)
            onnx_output_normalized = onnx_output / np.clip(onnx_quat_norm, 1e-12, None)

            # Canonicalize sign: handle edge cases where w ≈ 0
            def canonicalize_quaternion(q):
                """Canonicalize quaternion to ensure unique representation."""
                abs_q = np.abs(q)
                max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
                selector = np.zeros_like(q)
                np.put_along_axis(selector, max_component_idx, 1, axis=-1)
                max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
                return np.where(max_component_sign < 0, -q, q)

            pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
            onnx_output_canonical = canonicalize_quaternion(onnx_output_normalized)

            # Compute differences with canonicalized quaternions
            diff = np.abs(pt_output_canonical - onnx_output_canonical)
            max_diff = np.max(diff)
            mean_diff = np.mean(diff)

            # Angular difference for rotations
            dot_products = np.sum(pt_output_canonical * onnx_output_canonical, axis=-1)
            dot_products = np.clip(np.abs(dot_products), 0.0, 1.0)
            angular_diff_rad = 2 * np.arccos(dot_products)
            angular_diff_deg = np.degrees(angular_diff_rad)
            max_angular = np.max(angular_diff_deg)
            mean_angular = np.mean(angular_diff_deg)
            p99_angular = np.percentile(angular_diff_deg, 99)

            quat_passed = True
            failure_reasons = []

            if mean_angular > angular_tolerances["mean"]:
                quat_passed = False
                failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
            if p99_angular > angular_tolerances["p99"]:
                quat_passed = False
                failure_reasons.append(f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°")
            if max_angular > angular_tolerances["max"]:
                quat_passed = False
                failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")

            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{mean_diff:.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "max_angular": f"{max_angular:.4f}",
                "mean_angular": f"{mean_angular:.4f}",
                "p99_angular": f"{p99_angular:.4f}",
                "passed": quat_passed,
                "failure_reason": "; ".join(failure_reasons) if failure_reasons else ""
            })

            if not quat_passed:
                all_passed = False
        else:
            diff = np.abs(pt_output - onnx_output)
            max_diff = np.max(diff)
            mean_diff = np.mean(diff)
            p99_diff = np.percentile(diff, 99)

            output_tolerance = tolerances.get(name, tolerance)

            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{mean_diff:.6f}",
                "p99_diff": f"{p99_diff:.6f}",
                "tolerance": f"{output_tolerance:.6f}"
            })

            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Output validation results as markdown table
    if validation_results:
        LOGGER.info("\n### Validation Results\n")
        LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
        LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")

        for result in validation_results:
            output_name = result["output"].replace("_", " ").title()
            max_diff = result["max_diff"]
            mean_diff = result["mean_diff"]
            p99_diff = result["p99_diff"]

            if "max_angular" in result:
                angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
            else:
                angular_info = "-"

            status = "✅ PASS" if result["passed"] else "❌ FAIL"
            if result["failure_reason"]:
                status += f" ({result['failure_reason']})"

            LOGGER.info(f"| {output_name} | {max_diff} | {mean_diff} | {p99_diff} | {angular_info} | {status} |")

        LOGGER.info("")

    return all_passed


def main():
    """Main conversion script."""
    parser = argparse.ArgumentParser(
        description="Convert SHARP PyTorch model to ONNX format"
    )
    parser.add_argument(
        "-c", "--checkpoint",
        type=Path,
        default=None,
        help="Path to PyTorch checkpoint. Downloads default if not provided.",
    )
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=Path("sharp.onnx"),
        help="Output path for ONNX model (default: sharp.onnx)",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1536,
        help="Input image height (default: 1536)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1536,
        help="Input image width (default: 1536)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate ONNX model against PyTorch",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load PyTorch model
    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    # Setup conversion parameters
    input_shape = (args.height, args.width)

    # Convert to ONNX
    LOGGER.info(f"Converting to ONNX: {args.output}")
    convert_to_onnx(predictor, args.output, input_shape=input_shape)
    LOGGER.info(f"ONNX model saved to {args.output}")

    # Validate if requested
    if args.validate:
        if args.output.exists():
            validation_passed = validate_onnx_model(args.output, predictor, input_shape)
            if validation_passed:
                LOGGER.info("✓ Validation passed!")
            else:
                LOGGER.error("✗ Validation failed!")
                return 1
        else:
            LOGGER.error(f"ONNX model not found at {args.output} for validation")
            return 1

    LOGGER.info("Conversion complete!")
    return 0


if __name__ == "__main__":
    exit(main())