descentbrine and pearsonkyle committed
Commit 71fa15d · 0 Parent(s)

Duplicate from pearsonkyle/Sharp-coreml

Co-authored-by: Kyle Pearson <pearsonkyle@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,41 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ test.ply filter=lfs diff=lfs merge=lfs -text
+ test.gif filter=lfs diff=lfs merge=lfs -text
+ test.png filter=lfs diff=lfs merge=lfs -text
+ sharp.mlpackage/ filter=lfs diff=lfs merge=lfs -text
+ viewer.gif filter=lfs diff=lfs merge=lfs -text
+ sharp.mlpackage filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ .DS_Store
+ __pycache__/
+ onnx__*
+ monodepth_*
+ feature_model*
+ _Constant_*
+ _init_model_*
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "ml-sharp"]
+ 	path = ml-sharp
+ 	url = https://github.com/apple/ml-sharp
README.md ADDED
@@ -0,0 +1,172 @@
+ ---
+ license: apple-amlr
+ library_name: ml-sharp
+ pipeline_tag: image-to-3d
+ base_model: apple/Sharp
+ tags:
+ - coreml
+ - monocular-view-synthesis
+ - gaussian-splatting
+ ---
+
+
+ # Sharp Monocular View Synthesis in Less Than a Second (Core ML Edition)
+
+ [![Project Page](https://img.shields.io/badge/Project-Page-green)](https://apple.github.io/ml-sharp/)
+ [![arXiv](https://img.shields.io/badge/arXiv-2512.10685-b31b1b.svg)](https://arxiv.org/abs/2512.10685)
+
+
+ This software project is a community contribution and is not affiliated with the original research paper:
+
+
+ > _Sharp Monocular View Synthesis in Less Than a Second_ by _Lars Mescheder, Wei Dong, Shiwei Li, Xuyang Bai, Marcel Santos, Peiyun Hu, Bruno Lecouat, Mingmin Zhen, Amaël Delaunoy, Tian Fang, Yanghai Tsin, Stephan Richter and Vladlen Koltun_.
+
+ > We present SHARP, an approach to photorealistic view synthesis from a single image. Given a single photograph, SHARP regresses the parameters of a 3D Gaussian representation of the depicted scene. This is done in less than a second on a standard GPU via a single feedforward pass through a neural network. The 3D Gaussian representation produced by SHARP can then be rendered in real time, yielding high-resolution photorealistic images for nearby views. The representation is metric, with absolute scale, supporting metric camera movements.
+
+ #### This release includes a fully validated **Core ML (.mlpackage)** version of SHARP, optimized for CPU, GPU, and Neural Engine inference on macOS and iOS.
+
+ ![](viewer.gif)
+
+ Rendered using [Splat Viewer](https://huggingface.co/spaces/pearsonkyle/Gaussian-Splat-Viewer)
+
+ ## Getting started
+
+ ### 📦 Download the Core ML Model Only
+
+ ```bash
+ pip install huggingface-hub
+ huggingface-cli download --include sharp.mlpackage/ --local-dir . pearsonkyle/Sharp-coreml
+ ```
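+
+ Or from Python (a minimal sketch using `huggingface_hub`; the `allow_patterns` filter here is assumed to match only the model package):
+
+ ```python
+ # Sketch: fetch just sharp.mlpackage from the Hugging Face Hub.
+ from huggingface_hub import snapshot_download
+
+ snapshot_download(
+     repo_id="pearsonkyle/Sharp-coreml",
+     allow_patterns="sharp.mlpackage/*",  # download only the Core ML package
+     local_dir=".",
+ )
+ ```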
+
+ ### 🧰 Clone the Full Repository
+
+ This will include the inference and model conversion/validation scripts.
+
+ ```bash
+ brew install git-xet
+ git xet install
+ ```
+
+ Clone the model repository:
+
+ ```bash
+ git clone git@hf.co:pearsonkyle/Sharp-coreml
+ ```
+
+
+ ### 📱 Run Inference on Apple Devices
+
+ Use the provided [sharp.swift](sharp.swift) inference script to load the model and generate 3D Gaussian splats (PLY) from any image:
+
+ ```bash
+ # Compile the Swift runner (requires Xcode command-line tools)
+ swiftc -O -o run_sharp sharp.swift -framework CoreML -framework CoreImage -framework AppKit
+
+ # Run inference on an image and decimate the output by 50%
+ ./run_sharp sharp.mlpackage test.png test.ply -d 0.5
+ ```
+
+ > Inference on an Apple M4 Max takes ~1.9 seconds.
+
+ **CLI Features:**
+ - Automatic model compilation and caching
+ - Decimation to reduce point cloud size while preserving visual fidelity (see the sketch after the usage text below)
+ - Input is expected as a standard RGB image; conversion to [0,1] and CHW format happens inside the model
+ - PLY output compatible with [Splat Viewer](https://huggingface.co/spaces/pearsonkyle/Gaussian-Splat-Viewer), [MetalSplatter](https://github.com/scier/MetalSplatter), and [Three.js](https://threejs.org)
+
+
+ ```bash
+ Usage: run_sharp [OPTIONS] <model> <input_image> <output.ply>
+
+ SHARP Model Inference - Generate 3D Gaussian Splats from a single image
+
+ Arguments:
+   model                     Path to the SHARP Core ML model (.mlpackage, .mlmodel, or .mlmodelc)
+   input_image               Path to input image (PNG, JPEG, etc.)
+   output.ply                Path for output PLY file
+
+ Options:
+   -m, --model PATH          Path to Core ML model
+   -i, --input PATH          Path to input image
+   -o, --output PATH         Path for output PLY file
+   -f, --focal-length FLOAT  Focal length in pixels (default: 1536)
+   -d, --decimation FLOAT    Decimation ratio 0.0-1.0 or percentage 1-100 (default: 1.0 = keep all)
+                             Example: 0.5 or 50 keeps 50% of Gaussians
+   -h, --help                Show this help message
+ ```
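+
+ The exact decimation strategy inside `sharp.swift` is not reproduced here; a minimal NumPy sketch of the same idea (uniform random subsampling at a hypothetical `keep` ratio, applied to the output arrays described in the next section) would be:
+
+ ```python
+ # Sketch: keep a random fraction of Gaussians (uniform subsampling).
+ import numpy as np
+
+ def decimate(gaussians: dict, keep: float = 0.5, seed: int = 0) -> dict:
+     """Subsample arrays of shape (1, N, ...) along the Gaussian axis N."""
+     n = gaussians["mean_vectors_3d_positions"].shape[1]
+     rng = np.random.default_rng(seed)
+     idx = rng.choice(n, size=int(n * keep), replace=False)
+     return {name: arr[:, idx] for name, arr in gaussians.items()}
+ ```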
+
+ ## Model Input and Output
+
+ ### 📥 Input
+ The Core ML model accepts two inputs:
+
+ - **`image`**: A 3-channel RGB image in `uint8` format with shape `(1, 3, H, W)`.
+   - Values are expected in range `[0, 255]` (no manual normalization required).
+   - Recommended resolution: `1536×1536` (matches the training size).
+   - Aspect ratio is preserved; the input will be resized internally if needed.
+
+ - **`disparity_factor`**: A scalar tensor of shape `(1,)` representing the ratio `focal_length / image_width` (see the sketch below).
+   - Use `1.0` for standard cameras (e.g., a typical smartphone or DSLR).
+   - Adjust slightly to control depth scale: higher values bring objects closer, lower values push the scene farther away.
+   - If using the `sharp.swift` runner, this input is computed automatically from your image dimensions.
+
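+ For scripting on macOS, the same inputs can be driven from Python with `coremltools` (a minimal sketch; the image dtype handling and the default `focal_px` value are assumptions, so match them to the model's input description and your image metadata):
+
+ ```python
+ # Sketch: run sharp.mlpackage from Python via coremltools.
+ import coremltools as ct
+ import numpy as np
+ from PIL import Image
+
+ model = ct.models.MLModel("sharp.mlpackage")
+
+ img = Image.open("test.png").convert("RGB").resize((1536, 1536))
+ image = np.asarray(img).transpose(2, 0, 1)[None]  # HWC uint8 -> (1, 3, H, W)
+
+ focal_px = 1536.0  # assumed default; use the EXIF focal length in pixels if available
+ disparity_factor = np.array([focal_px / image.shape[-1]], dtype=np.float32)
+
+ outputs = model.predict({"image": image, "disparity_factor": disparity_factor})
+ ```
+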
+ ### 📤 Output
+ The model outputs five tensors that together form a 3D Gaussian splat representation of the scene:
+
+ | Output | Shape | Description |
+ |--------|-------|-------------|
+ | `mean_vectors_3d_positions` | `(1, N, 3)` | 3D positions in Normalized Device Coordinates (NDC) — x, y, z. |
+ | `singular_values_scales` | `(1, N, 3)` | Scale parameters along each principal axis (width, height, depth). |
+ | `quaternions_rotations` | `(1, N, 4)` | Unit quaternions `[w, x, y, z]` encoding the orientation of each Gaussian. |
+ | `colors_rgb_linear` | `(1, N, 3)` | Linear RGB color values in range `[0, 1]` (no gamma correction). |
+ | `opacities_alpha_channel` | `(1, N)` | Opacity (alpha) values per Gaussian, in range `[0, 1]`. |
+
+ The total number of Gaussians `N` is approximately 1,179,648 for the default model.
+
+ > 🌍 These outputs are fully compatible with [Splat Viewer](https://huggingface.co/spaces/pearsonkyle/Gaussian-Splat-Viewer) and [MetalSplatter](https://github.com/scier/MetalSplatter).
+
+
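+ Continuing the Python sketch above, the prediction comes back as a dict keyed by these output names:
+
+ ```python
+ # Sketch: unpack the five output tensors from the coremltools prediction.
+ means     = outputs["mean_vectors_3d_positions"]   # (1, N, 3) positions
+ scales    = outputs["singular_values_scales"]      # (1, N, 3) per-axis scales
+ quats     = outputs["quaternions_rotations"]       # (1, N, 4) [w, x, y, z]
+ colors    = outputs["colors_rgb_linear"]           # (1, N, 3) linear RGB
+ opacities = outputs["opacities_alpha_channel"]     # (1, N)   alpha in [0, 1]
+
+ print(f"{means.shape[1]} Gaussians, opacity range "
+       f"[{opacities.min():.3f}, {opacities.max():.3f}]")
+ ```
+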
+ ### 🔍 Model Validation Results
+
+ The Core ML model has been rigorously validated against the original PyTorch implementation. Below are the numerical accuracy metrics across all 5 output tensors:
+
+ | Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°, max / mean / p99) | Status |
+ |--------|----------|-----------|----------|------------------------------------|--------|
+ | Mean Vectors (3D Positions) | 0.000794 | 0.000049 | 0.000094 | - | ✅ PASS |
+ | Singular Values (Scales) | 0.000035 | 0.000000 | 0.000002 | - | ✅ PASS |
+ | Quaternions (Rotations) | 1.425558 | 0.000024 | 0.000067 | 9.2519 / 0.0019 / 0.0396 | ✅ PASS |
+ | Colors (RGB Linear) | 0.001440 | 0.000005 | 0.000055 | - | ✅ PASS |
+ | Opacities (Alpha) | 0.004183 | 0.000005 | 0.000114 | - | ✅ PASS |
+
+ > **Validation Notes:**
+ > - All outputs match PyTorch within 0.01% mean error.
+ > - Quaternion angular errors are below 1° for 99% of Gaussians.
+
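+ For reference, the angular error above is the rotation angle between the two predicted quaternions, θ = 2 · arccos(|q₁ · q₂|) after normalization; a standalone sketch of the metric used by `convert.py`:
+
+ ```python
+ # Sketch: quaternion angular difference, robust to the q vs. -q sign ambiguity.
+ import numpy as np
+
+ def angular_diff_deg(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
+     """Rotation angle in degrees between quaternion arrays of shape (..., 4)."""
+     q1 = q1 / np.linalg.norm(q1, axis=-1, keepdims=True)
+     q2 = q2 / np.linalg.norm(q2, axis=-1, keepdims=True)
+     dot = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)  # |q1·q2| handles q ≡ -q
+     return np.degrees(2.0 * np.arccos(dot))
+ ```
+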
+ ## Reproducing the Conversion
+
+ To reproduce the conversion from PyTorch to Core ML, follow these steps:
+ ```bash
+ git clone https://github.com/apple/ml-sharp.git
+ cd ml-sharp
+ conda create -n sharp python=3.13
+ conda activate sharp
+ pip install -r requirements.txt
+ pip install coremltools
+ cd ../
+ python convert.py
+ ```
+
+ ## Citation
+
+ If you find this work useful, please cite the original paper:
+
+ ```bibtex
+ @article{Sharp2025:arxiv,
+   title   = {Sharp Monocular View Synthesis in Less Than a Second},
+   author  = {Lars Mescheder and Wei Dong and Shiwei Li and Xuyang Bai and Marcel Santos and Peiyun Hu and Bruno Lecouat and Mingmin Zhen and Ama\"{e}l Delaunoy and Tian Fang and Yanghai Tsin and Stephan R. Richter and Vladlen Koltun},
+   journal = {arXiv preprint arXiv:2512.10685},
+   year    = {2025},
+   url     = {https://arxiv.org/abs/2512.10685},
+ }
+ ```
+
convert.py ADDED
@@ -0,0 +1,1620 @@
+ """Convert SHARP PyTorch model to Core ML .mlmodel format.
+
+ This script converts the SHARP (Sharp Monocular View Synthesis) model
+ from PyTorch (.pt) to Core ML (.mlmodel) format for deployment on Apple devices.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+ import coremltools as ct
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+
+ # Import SHARP model components
+ from sharp.models import PredictorParams, create_predictor
+ from sharp.models.predictor import RGBGaussianPredictor
+ from sharp.utils import io
+
+ LOGGER = logging.getLogger(__name__)
+
+ DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
+
+ # ============================================================================
+ # Constants & Configuration
+ # ============================================================================
+
+ # Output names for Core ML model
+ OUTPUT_NAMES = [
+     "mean_vectors_3d_positions",
+     "singular_values_scales",
+     "quaternions_rotations",
+     "colors_rgb_linear",
+     "opacities_alpha_channel",
+ ]
+
+ # Output descriptions for Core ML metadata
+ OUTPUT_DESCRIPTIONS = {
+     "mean_vectors_3d_positions": (
+         "3D positions of Gaussian splats in normalized device coordinates (NDC). "
+         "Shape: (1, N, 3), where N is the number of Gaussians."
+     ),
+     "singular_values_scales": (
+         "Scale factors for each Gaussian along its principal axes. "
+         "Represents size and anisotropy. Shape: (1, N, 3)."
+     ),
+     "quaternions_rotations": (
+         "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
+         "Used to orient the ellipsoid. Shape: (1, N, 4)."
+     ),
+     "colors_rgb_linear": (
+         "RGB color values in linear RGB space (not gamma-corrected). "
+         "Shape: (1, N, 3), with range [0, 1]."
+     ),
+     "opacities_alpha_channel": (
+         "Opacity value per Gaussian (alpha channel), used for blending. "
+         "Shape: (1, N), where values are in [0, 1]."
+     ),
+ }
+
+
+ @dataclass
+ class ToleranceConfig:
+     """Tolerance configuration for validation."""
+
+     # Tolerances for random validation (tight)
+     random_tolerances: dict[str, float] | None = None
+
+     # Tolerances for real image validation (more lenient)
+     image_tolerances: dict[str, float] | None = None
+
+     # Angular tolerances for quaternions (in degrees)
+     angular_tolerances_random: dict[str, float] | None = None
+     angular_tolerances_image: dict[str, float] | None = None
+
+     def __post_init__(self):
+         if self.random_tolerances is None:
+             self.random_tolerances = {
+                 "mean_vectors_3d_positions": 0.001,
+                 "singular_values_scales": 0.0001,
+                 "quaternions_rotations": 2.0,
+                 "colors_rgb_linear": 0.002,
+                 "opacities_alpha_channel": 0.005,
+             }
+
+         if self.image_tolerances is None:
+             self.image_tolerances = {
+                 "mean_vectors_3d_positions": 3.5,  # Increased to account for depth scaling with focal length
+                 "singular_values_scales": 0.035,  # Increased proportionally (scales are depth-dependent)
+                 "quaternions_rotations": 5.0,
+                 "colors_rgb_linear": 0.01,
+                 "opacities_alpha_channel": 0.05,
+             }
+
+         if self.angular_tolerances_random is None:
+             self.angular_tolerances_random = {
+                 "mean": 0.01,
+                 "p99": 0.1,
+                 "p99_9": 1.0,
+                 "max": 5.0,
+             }
+
+         if self.angular_tolerances_image is None:
+             self.angular_tolerances_image = {
+                 "mean": 0.2,
+                 "p99": 2.0,
+                 "p99_9": 5.0,
+                 "max": 25.0,
+             }
+
+
+ class SharpModelTraceable(nn.Module):
+     """Fully traceable version of SHARP for Core ML conversion.
+
+     This version removes all dynamic control flow and makes the model
+     fully traceable with torch.jit.trace.
+     """
+
+     def __init__(self, predictor: RGBGaussianPredictor):
+         """Initialize the traceable wrapper.
+
+         Args:
+             predictor: The SHARP RGBGaussianPredictor model.
+         """
+         super().__init__()
+         # Copy all submodules
+         self.init_model = predictor.init_model
+         self.feature_model = predictor.feature_model
+         self.monodepth_model = predictor.monodepth_model
+         self.prediction_head = predictor.prediction_head
+         self.gaussian_composer = predictor.gaussian_composer
+         self.depth_alignment = predictor.depth_alignment
+
+         # For debugging: store global_scale
+         self.last_global_scale = None
+         self.last_monodepth_min = None
+
+     def forward(
+         self,
+         image: torch.Tensor,
+         disparity_factor: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Run inference with traceable forward pass.
+
+         Args:
+             image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
+             disparity_factor: Disparity factor tensor of shape (1,).
+
+         Returns:
+             Tuple of 5 tensors representing 3D Gaussians.
+         """
+         # Estimate depth using monodepth
+         monodepth_output = self.monodepth_model(image)
+         monodepth_disparity = monodepth_output.disparity
+
+         # Convert disparity to depth - use float32 to match Core ML execution
+         # Core ML uses float32 precision, so using double() here creates a mismatch
+         disparity_factor_expanded = disparity_factor[:, None, None, None]
+
+         # Clamp disparity to prevent numerical instability (matches model exactly)
+         disparity_clamped = monodepth_disparity.clamp(min=1e-4, max=1e4)
+         monodepth = disparity_factor_expanded / disparity_clamped
+
+         # Apply depth alignment (inference mode)
+         monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)
+
+         # Store monodepth min for debugging (before normalization)
+         if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+             self.last_monodepth_min = monodepth.flatten().min().item()
+
+         # Initialize gaussians
+         init_output = self.init_model(image, monodepth)
+
+         # Store global_scale for debugging
+         if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+             if init_output.global_scale is not None:
+                 self.last_global_scale = init_output.global_scale.item()
+
+         # Extract features
+         image_features = self.feature_model(
+             init_output.feature_input,
+             encodings=monodepth_output.output_features
+         )
+
+         # Predict deltas
+         delta_values = self.prediction_head(image_features)
+
+         # Compose final gaussians
+         gaussians = self.gaussian_composer(
+             delta=delta_values,
+             base_values=init_output.gaussian_base_values,
+             global_scale=init_output.global_scale,
+         )
+
+         # Normalize quaternions for consistent validation and inference
+         #
+         # IMPORTANT: The SHARP model does NOT canonicalize quaternions during inference.
+         # Quaternions are normalized to unit length but retain their sign ambiguity (q ≡ -q).
+         #
+         # We canonicalize here for two reasons:
+         # 1. Numerical validation: Ensures PyTorch and Core ML outputs can be compared directly
+         # 2. Consistency: Provides deterministic outputs for the same rotation
+         #
+         # This canonicalization is NOT required for rendering, as both q and -q represent
+         # the same 3D rotation. Renderers typically normalize quaternions internally.
+         quaternions = gaussians.quaternions
+
+         # Normalize quaternions to unit length
+         # Use float32 to match Core ML precision
+         quat_norm_sq = torch.sum(quaternions * quaternions, dim=-1, keepdim=True)
+         quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-12))
+         quaternions_normalized = quaternions / quat_norm
+
+         # Apply sign canonicalization for consistent representation
+         # Ensure the component with largest absolute value is positive
+         abs_quat = torch.abs(quaternions_normalized)
+         max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)
+
+         # Create one-hot selector for the max component
+         one_hot = torch.zeros_like(quaternions_normalized)
+         one_hot.scatter_(-1, max_idx, 1.0)
+
+         # Get the sign of the max component
+         max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)
+
+         # Canonicalize: flip if max component is negative
+         # This matches the validation logic: np.where(max_component_sign < 0, -q, q)
+         quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()
+
+         return (
+             gaussians.mean_vectors,
+             gaussians.singular_values,
+             quaternions,
+             gaussians.colors,
+             gaussians.opacities,
+         )
+
+
+ def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
+     """Load SHARP model from checkpoint.
+
+     Args:
+         checkpoint_path: Path to the .pt checkpoint file.
+             If None, downloads the default model.
+
+     Returns:
+         The loaded RGBGaussianPredictor model in eval mode.
+     """
+     if checkpoint_path is None:
+         LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
+         state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
+     else:
+         LOGGER.info("Loading checkpoint from %s", checkpoint_path)
+         state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
+
+     # Create model with default parameters
+     predictor = create_predictor(PredictorParams())
+     predictor.load_state_dict(state_dict)
+     predictor.eval()
+
+     return predictor
+
+
+ def convert_to_coreml(
+     predictor: RGBGaussianPredictor,
+     output_path: Path,
+     input_shape: tuple[int, int] = (1536, 1536),
+     compute_precision: ct.precision = ct.precision.FLOAT16,
+     compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
+     minimum_deployment_target: ct.target | None = None,
+ ) -> ct.models.MLModel:
+     """Convert SHARP model to Core ML format.
+
+     Args:
+         predictor: The SHARP RGBGaussianPredictor model.
+         output_path: Path to save the .mlmodel file.
+         input_shape: Input image shape (height, width). Default is (1536, 1536).
+         compute_precision: Precision for compute (FLOAT16 or FLOAT32).
+         compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
+         minimum_deployment_target: Minimum iOS/macOS deployment target.
+
+     Returns:
+         The converted Core ML model.
+     """
+     LOGGER.info("Preparing model for Core ML conversion...")
+
+     # Ensure depth alignment is disabled for inference
+     predictor.depth_alignment.scale_map_estimator = None
+
+     # Create traceable wrapper
+     model_wrapper = SharpModelTraceable(predictor)
+     model_wrapper.eval()
+
+     # Pre-warm the model with a few forward passes for better tracing
+     LOGGER.info("Pre-warming model for better tracing...")
+     with torch.no_grad():
+         for _ in range(3):
+             warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
+             warm_disparity = torch.tensor([1.0])
+             _ = model_wrapper(warm_image, warm_disparity)
+
+     # Create deterministic example inputs for tracing (same as validation)
+     height, width = input_shape
+     torch.manual_seed(42)  # Use same seed as validation for consistency
+     example_image = torch.randn(1, 3, height, width)
+     example_disparity_factor = torch.tensor([1.0])
+
+     LOGGER.info("Attempting torch.jit.script for better tracing...")
+     try:
+         with torch.no_grad():
+             scripted_model = torch.jit.script(model_wrapper)
+         LOGGER.info("torch.jit.script succeeded, using scripted model")
+         traced_model = scripted_model
+     except Exception as e:
+         LOGGER.warning(f"torch.jit.script failed: {e}")
+         LOGGER.info("Falling back to torch.jit.trace...")
+         with torch.no_grad():
+             traced_model = torch.jit.trace(
+                 model_wrapper,
+                 (example_image, example_disparity_factor),
+                 strict=False,  # Allow some flexibility for complex models
+                 check_trace=False,  # Skip trace checking to allow more flexibility
+             )
+
+     LOGGER.info("Converting traced model to Core ML...")
+
+     # Define input types for Core ML
+     inputs = [
+         ct.TensorType(
+             name="image",
+             shape=(1, 3, height, width),
+             dtype=np.float32,
+         ),
+         ct.TensorType(
+             name="disparity_factor",
+             shape=(1,),
+             dtype=np.float32,
+         ),
+     ]
+
+     # Define output names with clear, descriptive labels
+     output_names = [
+         "mean_vectors_3d_positions",  # 3D positions (NDC space)
+         "singular_values_scales",  # Scale parameters (diagonal of covariance)
+         "quaternions_rotations",  # Rotation as quaternions
+         "colors_rgb_linear",  # RGB colors in linear color space
+         "opacities_alpha_channel",  # Opacity values (alpha)
+     ]
+
+     # Define outputs with proper names for Core ML conversion
+     outputs = [
+         ct.TensorType(name=output_names[0], dtype=np.float32),
+         ct.TensorType(name=output_names[1], dtype=np.float32),
+         ct.TensorType(name=output_names[2], dtype=np.float32),
+         ct.TensorType(name=output_names[3], dtype=np.float32),
+         ct.TensorType(name=output_names[4], dtype=np.float32),
+     ]
+
+     # Set up conversion config
+     conversion_kwargs: dict[str, Any] = {
+         "inputs": inputs,
+         "outputs": outputs,  # Specify output names during conversion
+         "convert_to": "mlprogram",  # Use ML Program format for better performance
+         "compute_precision": compute_precision,
+         "compute_units": compute_units,
+     }
+
+     if minimum_deployment_target is not None:
+         conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target
+
+     # Convert to Core ML
+     mlmodel = ct.convert(
+         traced_model,
+         **conversion_kwargs,
+     )
+
+     # Add metadata
+     mlmodel.author = "Apple Inc."
+     mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
+     mlmodel.short_description = (
+         "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
+     )
+     mlmodel.version = "1.0.0"
+
+     # Update output names and descriptions via spec BEFORE saving
+     spec = mlmodel.get_spec()
+
+     # Input descriptions
+     input_descriptions = {
+         "image": "RGB image normalized to [0, 1], shape (1, 3, H, W)",
+         "disparity_factor": "Focal length / image width ratio, shape (1,)",
+     }
+
+     # Output descriptions with clear intent and units
+     output_descriptions = {
+         "mean_vectors_3d_positions": (
+             "3D positions of Gaussian splats in normalized device coordinates (NDC). "
+             "Shape: (1, N, 3), where N is the number of Gaussians."
+         ),
+         "singular_values_scales": (
+             "Scale factors for each Gaussian along its principal axes. "
+             "Represents size and anisotropy. Shape: (1, N, 3)."
+         ),
+         "quaternions_rotations": (
+             "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
+             "Used to orient the ellipsoid. Shape: (1, N, 4)."
+         ),
+         "colors_rgb_linear": (
+             "RGB color values in linear RGB space (not gamma-corrected). "
+             "Shape: (1, N, 3), with range [0, 1]."
+         ),
+         "opacities_alpha_channel": (
+             "Opacity value per Gaussian (alpha channel), used for blending. "
+             "Shape: (1, N), where values are in [0, 1]."
+         ),
+     }
+
+     # Update output names and descriptions
+     for i, name in enumerate(output_names):
+         if i < len(spec.description.output):
+             output = spec.description.output[i]
+             output.name = name  # Update name
+             output.shortDescription = output_descriptions[name]  # Add description
+
+     # Validate output names are set correctly
+     LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])
+
+     # Save the model with correct names
+     LOGGER.info("Saving Core ML model to %s", output_path)
+     mlmodel.save(str(output_path))
+
+     return mlmodel
+
+
+ class QuaternionValidator:
+     """Validator for quaternion comparisons with configurable tolerances and outlier analysis."""
+
+     DEFAULT_ANGULAR_TOLERANCES = {
+         "mean": 0.01,
+         "p99": 0.5,
+         "p99_9": 2.0,
+         "max": 15.0,
+     }
+
+     def __init__(
+         self,
+         angular_tolerances: dict[str, float] | None = None,
+         enable_outlier_analysis: bool = True,
+         outlier_thresholds: list[float] | None = None,
+     ):
+         """Initialize validator with tolerances.
+
+         Args:
+             angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees.
+             enable_outlier_analysis: Whether to perform detailed outlier analysis.
+             outlier_thresholds: List of angle thresholds for outlier counting.
+         """
+         self.angular_tolerances = angular_tolerances or self.DEFAULT_ANGULAR_TOLERANCES.copy()
+         self.enable_outlier_analysis = enable_outlier_analysis
+         self.outlier_thresholds = outlier_thresholds or [5.0, 10.0, 15.0]
+
+     @staticmethod
+     def canonicalize_quaternion(q: np.ndarray) -> np.ndarray:
+         """Canonicalize quaternion to ensure consistent representation.
+
+         Ensures the quaternion with the largest absolute component is positive.
+         This handles the sign ambiguity where q and -q represent the same rotation.
+
+         Args:
+             q: Quaternion array of shape (..., 4)
+
+         Returns:
+             Canonicalized quaternion array.
+         """
+         abs_q = np.abs(q)
+         max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
+         selector = np.zeros_like(q)
+         np.put_along_axis(selector, max_component_idx, 1.0, axis=-1)
+         max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
+         return np.where(max_component_sign < 0, -q, q)
+
+     @staticmethod
+     def compute_angular_differences(
+         quats1: np.ndarray, quats2: np.ndarray
+     ) -> tuple[np.ndarray, dict[str, float]]:
+         """Compute angular differences between two sets of quaternions.
+
+         Args:
+             quats1: First set of quaternions shape (N, 4)
+             quats2: Second set of quaternions shape (N, 4)
+
+         Returns:
+             Tuple of (angular_differences in degrees, statistics dict)
+         """
+         # Normalize quaternions
+         norm1 = np.linalg.norm(quats1, axis=-1, keepdims=True)
+         norm2 = np.linalg.norm(quats2, axis=-1, keepdims=True)
+         quats1_norm = quats1 / np.clip(norm1, 1e-12, None)
+         quats2_norm = quats2 / np.clip(norm2, 1e-12, None)
+
+         # Canonicalize both
+         quats1_canon = QuaternionValidator.canonicalize_quaternion(quats1_norm)
+         quats2_canon = QuaternionValidator.canonicalize_quaternion(quats2_norm)
+
+         # Compute dot products for both q·q and q·(-q) to handle sign ambiguity
+         dot_products = np.sum(quats1_canon * quats2_canon, axis=-1)
+         dot_products_flipped = np.sum(quats1_canon * (-quats2_canon), axis=-1)
+
+         # Take the maximum absolute dot product (handle sign ambiguity)
+         dot_products = np.maximum(np.abs(dot_products), np.abs(dot_products_flipped))
+         dot_products = np.clip(dot_products, 0.0, 1.0)
+
+         # Compute angular differences
+         angular_diff_rad = 2.0 * np.arccos(dot_products)
+         angular_diff_deg = np.degrees(angular_diff_rad)
+
+         # Compute statistics
+         stats = {
+             "mean": float(np.mean(angular_diff_deg)),
+             "std": float(np.std(angular_diff_deg)),
+             "min": float(np.min(angular_diff_deg)),
+             "max": float(np.max(angular_diff_deg)),
+             "p50": float(np.percentile(angular_diff_deg, 50)),
+             "p90": float(np.percentile(angular_diff_deg, 90)),
+             "p99": float(np.percentile(angular_diff_deg, 99)),
+             "p99_9": float(np.percentile(angular_diff_deg, 99.9)),
+         }
+
+         return angular_diff_deg, stats
+
+     def analyze_outliers(
+         self, angular_diff_deg: np.ndarray
+     ) -> dict[str, dict[str, int | float]]:
+         """Analyze outliers in angular differences.
+
+         Args:
+             angular_diff_deg: Array of angular differences in degrees.
+
+         Returns:
+             Dict with outlier statistics for each threshold.
+         """
+         if not self.enable_outlier_analysis:
+             return {}
+
+         outlier_stats = {}
+         # Use .size rather than len() so (1, N) arrays count Gaussians, not the batch dim
+         total = int(angular_diff_deg.size)
+
+         for threshold in self.outlier_thresholds:
+             count = int(np.sum(angular_diff_deg > threshold))
+             outlier_stats[f">{threshold}°"] = {
+                 "count": count,
+                 "percentage": (count / total) * 100.0 if total > 0 else 0.0,
+             }
+
+         return outlier_stats
+
+     def validate(
+         self,
+         pt_quaternions: np.ndarray,
+         coreml_quaternions: np.ndarray,
+         image_name: str = "Unknown",
+     ) -> dict:
+         """Validate Core ML quaternions against PyTorch quaternions.
+
+         Args:
+             pt_quaternions: PyTorch quaternion outputs.
+             coreml_quaternions: Core ML quaternion outputs.
+             image_name: Name of the image being validated.
+
+         Returns:
+             Dict with validation results including status, stats, and outliers.
+         """
+         angular_diff_deg, stats = self.compute_angular_differences(
+             pt_quaternions, coreml_quaternions
+         )
+         outlier_stats = self.analyze_outliers(angular_diff_deg)
+
+         # Check tolerances
+         passed = True
+         failure_reasons = []
+
+         for key, tolerance in self.angular_tolerances.items():
+             if key in stats and stats[key] > tolerance:
+                 passed = False
+                 failure_reasons.append(
+                     f"{key} angular {stats[key]:.4f}° > tolerance {tolerance:.4f}°"
+                 )
+
+         return {
+             "image": image_name,
+             "passed": passed,
+             "failure_reasons": failure_reasons,
+             "stats": stats,
+             "outliers": outlier_stats,
+             "num_gaussians": int(angular_diff_deg.size),  # .size for (1, N) arrays
+         }
+
+
+ def find_coreml_output_key(name: str, coreml_outputs: dict) -> str:
+     """Find matching Core ML output key for a given output name.
+
+     Args:
+         name: The expected output name
+         coreml_outputs: Dictionary of Core ML outputs
+
+     Returns:
+         The matching key from coreml_outputs
+     """
+     if name in coreml_outputs:
+         return name
+
+     # Try partial match
+     base_name = name.split('_')[0]
+     for key in coreml_outputs:
+         if base_name in key.lower():
+             return key
+
+     # Fallback to index-based lookup
+     output_index = OUTPUT_NAMES.index(name) if name in OUTPUT_NAMES else 0
+     return list(coreml_outputs.keys())[output_index]
+
+
+ def run_inference_pair(
+     pytorch_model: RGBGaussianPredictor,
+     mlmodel: ct.models.MLModel,
+     image_tensor: torch.Tensor,
+     disparity_factor: float = 1.0,
+     log_internals: bool = False,
+ ) -> tuple[list[np.ndarray], dict[str, np.ndarray]]:
+     """Run inference on both PyTorch and Core ML models.
+
+     Args:
+         pytorch_model: The PyTorch model
+         mlmodel: The Core ML model
+         image_tensor: Input image tensor
+         disparity_factor: Disparity factor value
+         log_internals: Whether to log internal values for debugging
+
+     Returns:
+         Tuple of (pytorch_outputs, coreml_outputs)
+     """
+     # Run PyTorch model
+     traceable_wrapper = SharpModelTraceable(pytorch_model)
+     traceable_wrapper.eval()
+
+     # Ensure float32 dtype for model inference
+     image_tensor = image_tensor.float()
+
+     test_disparity_pt = torch.tensor([disparity_factor], dtype=torch.float32)
+     with torch.no_grad():
+         pt_outputs = traceable_wrapper(image_tensor, test_disparity_pt)
+
+     # Log internal values if requested
+     if log_internals:
+         if hasattr(traceable_wrapper, 'last_global_scale') and traceable_wrapper.last_global_scale is not None:
+             LOGGER.info(f"PyTorch global_scale: {traceable_wrapper.last_global_scale:.6f}")
+         if hasattr(traceable_wrapper, 'last_monodepth_min') and traceable_wrapper.last_monodepth_min is not None:
+             LOGGER.info(f"PyTorch monodepth_min: {traceable_wrapper.last_monodepth_min:.6f}")
+
+     # Convert to numpy
+     pt_outputs_np = [o.numpy() for o in pt_outputs]
+
+     # Run Core ML model
+     test_image_np = image_tensor.numpy()
+     test_disparity_np = np.array([disparity_factor], dtype=np.float32)
+     coreml_inputs = {
+         "image": test_image_np,
+         "disparity_factor": test_disparity_np,
+     }
+     coreml_outputs = mlmodel.predict(coreml_inputs)
+
+     return pt_outputs_np, coreml_outputs
+
+
+ def compare_outputs(
+     pt_outputs: list[np.ndarray],
+     coreml_outputs: dict[str, np.ndarray],
+     tolerances: dict[str, float],
+     quat_validator: QuaternionValidator,
+     image_name: str = "Unknown",
+ ) -> list[dict]:
+     """Compare PyTorch and Core ML outputs.
+
+     Args:
+         pt_outputs: List of PyTorch outputs
+         coreml_outputs: Dictionary of Core ML outputs
+         tolerances: Tolerance values per output type
+         quat_validator: QuaternionValidator instance
+         image_name: Name of the image being validated
+
+     Returns:
+         List of validation result dictionaries
+     """
+     validation_results = []
+
+     for i, name in enumerate(OUTPUT_NAMES):
+         pt_output = pt_outputs[i]
+         coreml_key = find_coreml_output_key(name, coreml_outputs)
+         coreml_output = coreml_outputs[coreml_key]
+
+         result = {"output": name, "passed": True, "failure_reason": ""}
+
+         if name == "quaternions_rotations":
+             # Use QuaternionValidator for quaternions
+             quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_name)
+
+             result.update({
+                 "max_diff": f"{quat_result['stats']['max']:.6f}",
+                 "mean_diff": f"{quat_result['stats']['mean']:.6f}",
+                 "p99_diff": f"{quat_result['stats']['p99']:.6f}",
+                 "passed": quat_result["passed"],
+                 "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
+             })
+         else:
+             # Standard numerical comparison
+             diff = np.abs(pt_output - coreml_output)
+             output_tolerance = tolerances.get(name, 0.01)
+             max_diff = np.max(diff)
+
+             result.update({
+                 "max_diff": f"{max_diff:.6f}",
+                 "mean_diff": f"{np.mean(diff):.6f}",
+                 "p99_diff": f"{np.percentile(diff, 99):.6f}",
+             })
+
+             if max_diff > output_tolerance:
+                 result["passed"] = False
+                 result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
+
+         validation_results.append(result)
+
+     return validation_results
+
+
+ def format_validation_table(
+     validation_results: list[dict],
+     image_name: str,
+     include_image_column: bool = False,
+ ) -> str:
+     """Format validation results as a markdown table.
+
+     Args:
+         validation_results: List of validation result dicts with keys:
+             output, max_diff, mean_diff, p99_diff, passed, etc.
+         image_name: Name of the image being validated.
+         include_image_column: Whether to include the image name as a column.
+
+     Returns:
+         Formatted markdown table as a string.
+     """
+     lines = []
+
+     if include_image_column:
+         lines.append("| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |")
+         lines.append("|-------|--------|----------|-----------|----------|--------|")
+
+         for result in validation_results:
+             output_name = result["output"].replace("_", " ").title()
+             status = "✅ PASS" if result["passed"] else "❌ FAIL"
+             lines.append(
+                 f"| {image_name} | {output_name} | {result['max_diff']} | "
+                 f"{result['mean_diff']} | {result['p99_diff']} | {status} |"
+             )
+     else:
+         lines.append("| Output | Max Diff | Mean Diff | P99 Diff | Status |")
+         lines.append("|--------|----------|-----------|----------|--------|")
+
+         for result in validation_results:
+             output_name = result["output"].replace("_", " ").title()
+             status = "✅ PASS" if result["passed"] else "❌ FAIL"
+             lines.append(
+                 f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | "
+                 f"{result['p99_diff']} | {status} |"
+             )
+
+     return "\n".join(lines)
+
+
+ def validate_coreml_model(
+     mlmodel: ct.models.MLModel,
+     pytorch_model: RGBGaussianPredictor,
+     input_shape: tuple[int, int] = (1536, 1536),
+     tolerance: float = 0.01,
+     angular_tolerances: dict[str, float] | None = None,
+ ) -> bool:
+     """Validate Core ML model outputs against PyTorch model.
+
+     Args:
+         mlmodel: The Core ML model to validate.
+         pytorch_model: The original PyTorch model.
+         input_shape: Input image shape (height, width).
+         tolerance: Maximum allowed difference between outputs.
+         angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees.
+
+     Returns:
+         True if validation passes, False otherwise.
+     """
+     LOGGER.info("Validating Core ML model against PyTorch...")
+
+     height, width = input_shape
+
+     # Set seeds for reproducibility
+     np.random.seed(42)
+     torch.manual_seed(42)
+
+     # Create test input
+     test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
+     test_disparity = np.array([1.0], dtype=np.float32)
+
+     # Run PyTorch model
+     test_image_pt = torch.from_numpy(test_image_np)
+     test_disparity_pt = torch.from_numpy(test_disparity)
+
+     traceable_wrapper = SharpModelTraceable(pytorch_model)
+     traceable_wrapper.eval()
+
+     with torch.no_grad():
+         pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)
+
+     # Run Core ML model
+     coreml_inputs = {
+         "image": test_image_np,
+         "disparity_factor": test_disparity,
+     }
+     coreml_outputs = mlmodel.predict(coreml_inputs)
+
+     LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
+     LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")
+
+     # Output configuration
+     output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
+
+     # Define tolerances per output type
+     tolerances = {
+         "mean_vectors_3d_positions": 0.001,
+         "singular_values_scales": 0.0001,
+         "quaternions_rotations": 2.0,
+         "colors_rgb_linear": 0.002,
+         "opacities_alpha_channel": 0.005,
+     }
+
+     # Use provided angular tolerances or defaults
+     if angular_tolerances is None:
+         angular_tolerances = {
+             "mean": 0.01,
+             "p99": 0.1,
+             "p99_9": 1.0,
+             "max": 5.0,
+         }
+
+     # Initialize quaternion validator
+     quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances)
+
+     all_passed = True
+
+     # Additional diagnostics for depth/position analysis
+     LOGGER.info("=== Depth/Position Statistics ===")
+     pt_positions = pt_outputs[0].numpy()
+     coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
+     coreml_positions = coreml_outputs[coreml_key]
+
+     LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
+     LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")
+
+     z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
+     LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
+     LOGGER.info("=================================")
+
+     # Collect validation results
+     validation_results = []
+
+     for i, name in enumerate(output_names):
+         pt_output = pt_outputs[i].numpy()
+
+         # Find matching Core ML output
+         coreml_key = None
+         if name in coreml_outputs:
+             coreml_key = name
+         else:
+             # Try partial match
+             base_name = name.split('_')[0]
+             for key in coreml_outputs:
+                 if base_name in key.lower():
+                     coreml_key = key
+                     break
+             if coreml_key is None:
+                 coreml_key = list(coreml_outputs.keys())[i]
+
+         coreml_output = coreml_outputs[coreml_key]
+         result = {"output": name, "passed": True, "failure_reason": ""}
+
+         # Special handling for quaternions
+         if name == "quaternions_rotations":
+             # Use the new QuaternionValidator
+             quat_result = quat_validator.validate(pt_output, coreml_output, image_name="Random")
+
+             result.update({
+                 "max_diff": f"{quat_result['stats']['max']:.6f}",
+                 "mean_diff": f"{quat_result['stats']['mean']:.6f}",
+                 "p99_diff": f"{quat_result['stats']['p99']:.6f}",
+                 "p99_9_diff": f"{quat_result['stats']['p99_9']:.6f}",
+                 "max_angular": f"{quat_result['stats']['max']:.4f}",
+                 "mean_angular": f"{quat_result['stats']['mean']:.4f}",
+                 "p99_angular": f"{quat_result['stats']['p99']:.4f}",
+                 "passed": quat_result["passed"],
+                 "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
+                 "quat_stats": quat_result["stats"],
+                 "outliers": quat_result["outliers"],
+             })
+             if not quat_result["passed"]:
+                 all_passed = False
+         else:
+             diff = np.abs(pt_output - coreml_output)
+             output_tolerance = tolerances.get(name, tolerance)
+             result.update({
+                 "max_diff": f"{np.max(diff):.6f}",
+                 "mean_diff": f"{np.mean(diff):.6f}",
+                 "p99_diff": f"{np.percentile(diff, 99):.6f}",
+                 "tolerance": f"{output_tolerance:.6f}"
+             })
+             if np.max(diff) > output_tolerance:
+                 result["passed"] = False
+                 result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
+                 all_passed = False
+
+         validation_results.append(result)
+
+     # Output validation results as markdown table
+     LOGGER.info("\n### Validation Results\n")
+     LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | P99.9 Diff | Angular Diff (°) | Status |")
+     LOGGER.info("|--------|----------|-----------|----------|------------|------------------|--------|")
+
+     for result in validation_results:
+         output_name = result["output"].replace("_", " ").title()
+         if "max_angular" in result:
+             angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
+             p99_9 = result.get("p99_9_diff", "-")
+             status = "✅ PASS" if result["passed"] else "❌ FAIL"
+             LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {p99_9} | {angular_info} | {status} |")
+         else:
+             status = "✅ PASS" if result["passed"] else "❌ FAIL"
+             LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | - | - | {status} |")
+     LOGGER.info("")
+
+     # Output quaternion outlier analysis if available
+     for result in validation_results:
+         if "outliers" in result and result["outliers"]:
+             LOGGER.info("### Quaternion Outlier Analysis\n")
+             LOGGER.info("| Threshold | Count | Percentage |")
+             LOGGER.info("|-----------|-------|------------|")
+             for threshold, data in result["outliers"].items():
+                 LOGGER.info(f"| {threshold} | {data['count']} | {data['percentage']:.4f}% |")
+             LOGGER.info("")
+
+     return all_passed
+
+
+ def load_and_preprocess_image(
+     image_path: Path,
+     target_size: tuple[int, int] = (1536, 1536),
+ ) -> tuple[torch.Tensor, float, tuple[int, int]]:
+     """Load and preprocess an input image for SHARP inference.
+
+     Args:
+         image_path: Path to the input image file.
+         target_size: Target (height, width) for resizing.
+
+     Returns:
+         Tuple of (preprocessed image tensor, focal_length_px, original_size)
+         - Preprocessed image tensor of shape (1, 3, H, W) in range [0, 1]
+         - Focal length in pixels (from EXIF or default)
+         - Original image size (width, height)
+     """
+     LOGGER.info(f"Loading image from {image_path}")
+
+     # Use the SHARP io utilities to load image with focal length
+     image_np, original_size, f_px = io.load_rgb(image_path)
+     LOGGER.info(f"Original image size: {original_size}, focal length: {f_px:.2f}px")
+
+     # Convert to torch and normalize - ensure float32 dtype
+     # io.load_rgb returns uint8, convert to float32 explicitly
+     image_tensor = torch.from_numpy(image_np).float() / 255.0
+     image_tensor = image_tensor.permute(2, 0, 1)  # HWC -> CHW
+     original_height, original_width = image_np.shape[:2]
+
+     # Resize to target size if different
+     if (original_width, original_height) != (target_size[1], target_size[0]):
+         LOGGER.info(f"Resizing to {target_size[1]}x{target_size[0]}")
+         import torch.nn.functional as F
+         image_tensor = F.interpolate(
+             image_tensor.unsqueeze(0),
+             size=(target_size[0], target_size[1]),
+             mode="bilinear",
+             align_corners=True,
+         ).squeeze(0)
+
+     # Add batch dimension
+     image_tensor = image_tensor.unsqueeze(0)  # (1, 3, H, W)
+
+     LOGGER.info(f"Preprocessed image shape: {image_tensor.shape}, range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]")
+
+     return image_tensor, f_px, (original_width, original_height)
+
+
+ def validate_with_image(
+     mlmodel: ct.models.MLModel,
+     pytorch_model: RGBGaussianPredictor,
+     image_path: Path,
+     input_shape: tuple[int, int] = (1536, 1536),
+ ) -> bool:
+     """Validate Core ML model outputs against PyTorch model using a real input image.
+
+     Args:
+         mlmodel: The Core ML model to validate.
+         pytorch_model: The original PyTorch model.
+         image_path: Path to the input image file.
+         input_shape: Expected input image shape (height, width).
+
+     Returns:
+         True if validation passes, False otherwise.
+     """
+     LOGGER.info("=" * 60)
+     LOGGER.info("Validating Core ML model against PyTorch with real image")
+     LOGGER.info("=" * 60)
+
+     # Load and preprocess the input image
+     # (load_and_preprocess_image returns a tuple; unpack it to get the tensor)
+     test_image, _focal_px, _original_size = load_and_preprocess_image(image_path, input_shape)
+     test_disparity = np.array([1.0], dtype=np.float32)
+
+     # Run PyTorch model
+     traceable_wrapper = SharpModelTraceable(pytorch_model)
+     traceable_wrapper.eval()
+
+     with torch.no_grad():
+         pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))
+
+     LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
+
+     # Run Core ML model
+     test_image_np = test_image.numpy()
+     coreml_inputs = {
+         "image": test_image_np,
+         "disparity_factor": test_disparity,
+     }
+     coreml_outputs = mlmodel.predict(coreml_inputs)
+
+     LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")
+
+     # Output configuration
+     output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
+
+     # Define tolerances per output type for real image validation
+     # Using p99-based tolerances to handle outliers better
+     tolerances = {
+         "mean_vectors_3d_positions": 1.2,
+         "singular_values_scales": 0.01,
+         "quaternions_rotations": 5.0,
+         "colors_rgb_linear": 0.01,
+         "opacities_alpha_channel": 0.05,
+     }
+
+     # Angular tolerances for quaternions (in degrees)
+     angular_tolerances = {
+         "mean": 0.1,
+         "p99": 1.0,
+         "max": 15.0,
+     }
+
+     all_passed = True
+
+     # Log input image statistics
+     LOGGER.info("\n=== Input Image Statistics ===")
+     LOGGER.info(f"Image path: {image_path}")
+     LOGGER.info(f"Image shape: {test_image.shape}")
+     LOGGER.info(f"Image range: [{test_image.min():.4f}, {test_image.max():.4f}]")
+     LOGGER.info(f"Image mean: {test_image.mean(dim=[1,2,3]).tolist()}")
+     LOGGER.info("=" * 30)
+
+     # Depth/position analysis
+     pt_positions = pt_outputs[0].numpy()
+     coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
+     coreml_positions = coreml_outputs[coreml_key]
+
+     LOGGER.info("\n=== Depth/Position Statistics ===")
+     LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
+     LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")
+
+     z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
+     LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
+     LOGGER.info("=================================\n")
+
+     # Collect validation results
+     validation_results = []
+
+     for i, name in enumerate(output_names):
+         pt_output = pt_outputs[i].numpy()
+
+         # Find matching Core ML output
+         coreml_key = None
+         if name in coreml_outputs:
+             coreml_key = name
+         else:
+             # Try partial match
+             base_name = name.split('_')[0]
+             for key in coreml_outputs:
+                 if base_name in key.lower():
+                     coreml_key = key
+                     break
+             if coreml_key is None:
+                 coreml_key = list(coreml_outputs.keys())[i]
+
+         coreml_output = coreml_outputs[coreml_key]
+         result = {"output": name, "passed": True, "failure_reason": ""}
+
+         # Special handling for quaternions
+         if name == "quaternions_rotations":
+             pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
+             pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)
+
+             coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
+             coreml_output_normalized = coreml_output / np.clip(coreml_quat_norm, 1e-12, None)
+
+             def canonicalize_quaternion(q):
+                 abs_q = np.abs(q)
+                 max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
+                 selector = np.zeros_like(q)
+                 np.put_along_axis(selector, max_component_idx, 1, axis=-1)
+                 max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
+                 return np.where(max_component_sign < 0, -q, q)
+
+             pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
+             coreml_output_canonical = canonicalize_quaternion(coreml_output_normalized)
+
+             diff = np.abs(pt_output_canonical - coreml_output_canonical)
+             dot_products = np.sum(pt_output_canonical * coreml_output_canonical, axis=-1)
+             dot_products_flipped = np.sum(pt_output_canonical * (-coreml_output_canonical), axis=-1)
+             # Take the absolute value and ensure we compare q with -q if needed
+             # This handles the sign ambiguity: q and -q represent the same rotation
+             dot_products = np.where(
+                 np.abs(dot_products) > np.abs(dot_products_flipped),
+                 np.abs(dot_products),
+                 np.abs(dot_products_flipped)
+             )
+             dot_products = np.clip(dot_products, 0.0, 1.0)
+             angular_diff_rad = 2 * np.arccos(dot_products)
+             angular_diff_deg = np.degrees(angular_diff_rad)
+             max_angular = np.max(angular_diff_deg)
+             mean_angular = np.mean(angular_diff_deg)
+             p99_angular = np.percentile(angular_diff_deg, 99)
+
+             quat_passed = True
+             failure_reasons = []
+
+             if mean_angular > angular_tolerances["mean"]:
+                 quat_passed = False
+                 failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
+             if p99_angular > angular_tolerances["p99"]:
+                 quat_passed = False
+                 failure_reasons.append(f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°")
+             if max_angular > angular_tolerances["max"]:
+                 quat_passed = False
+                 failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")
+
+             result.update({
+                 "max_diff": f"{np.max(diff):.6f}",
+                 "mean_diff": f"{np.mean(diff):.6f}",
+                 "p99_diff": f"{np.percentile(diff, 99):.6f}",
+                 "max_angular": f"{max_angular:.4f}",
+                 "mean_angular": f"{mean_angular:.4f}",
+                 "p99_angular": f"{p99_angular:.4f}",
+                 "passed": quat_passed,
+                 "failure_reason": "; ".join(failure_reasons) if failure_reasons else ""
+             })
+             if not quat_passed:
+                 all_passed = False
+         else:
+             diff = np.abs(pt_output - coreml_output)
+             output_tolerance = tolerances.get(name, 0.01)
+             result.update({
+                 "max_diff": f"{np.max(diff):.6f}",
+                 "mean_diff": f"{np.mean(diff):.6f}",
+                 "p99_diff": f"{np.percentile(diff, 99):.6f}",
+                 "tolerance": f"{output_tolerance:.6f}"
+             })
+             if np.max(diff) > output_tolerance:
+                 result["passed"] = False
+                 result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
+                 all_passed = False
+
+         validation_results.append(result)
+
+     # Output validation results as markdown table
+     LOGGER.info("\n### Image Validation Results\n")
+     LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
+     LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")
+
+     for result in validation_results:
+         output_name = result["output"].replace("_", " ").title()
+         if "max_angular" in result:
+             angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
1207
+ else:
1208
+ angular_info = "-"
1209
+ status = "✅ PASS" if result["passed"] else f"❌ FAIL"
1210
+ LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {angular_info} | {status} |")
1211
+ LOGGER.info("")
1212
+
1213
+ return all_passed
1214
+
1215
+
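The quaternion check above reduces to a standard identity: since q and -q encode the same rotation, the angular error between two unit quaternions is 2·arccos(|⟨q1, q2⟩|). A minimal standalone NumPy sketch of that comparison (illustrative, not part of the diff):

```python
import numpy as np

def angular_difference_deg(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Angle in degrees between unit quaternions, ignoring the q/-q sign ambiguity."""
    q1 = q1 / np.linalg.norm(q1, axis=-1, keepdims=True)
    q2 = q2 / np.linalg.norm(q2, axis=-1, keepdims=True)
    dot = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    return np.degrees(2.0 * np.arccos(dot))

# Identity vs. a 10-degree rotation about z; -qb gives the same answer.
theta = np.radians(10.0)
qa = np.array([1.0, 0.0, 0.0, 0.0])
qb = np.array([np.cos(theta / 2), 0.0, 0.0, np.sin(theta / 2)])
print(angular_difference_deg(qa, qb))  # ~10.0
```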
1216
+ def validate_with_image_set(
1217
+ mlmodel: ct.models.MLModel,
1218
+ pytorch_model: RGBGaussianPredictor,
1219
+ image_paths: list[Path],
1220
+ input_shape: tuple[int, int] = (1536, 1536),
1221
+ ) -> bool:
1222
+ """Validate Core ML model against PyTorch using multiple input images.
1223
+
1224
+ Args:
1225
+ mlmodel: The Core ML model to validate.
1226
+ pytorch_model: The original PyTorch model.
1227
+ image_paths: List of paths to input images for validation.
1228
+ input_shape: Expected input image shape (height, width).
1229
+
1230
+ Returns:
1231
+ True if all validations pass, False otherwise.
1232
+ """
1233
+ LOGGER.info("=" * 60)
1234
+ LOGGER.info(f"Validating Core ML model with {len(image_paths)} images")
1235
+ LOGGER.info("=" * 60)
1236
+
1237
+ # Angular tolerances for image validation (more lenient than random validation)
1238
+ # Real images have more variation than random noise
1239
+ angular_tolerances = {
1240
+ "mean": 0.2,
1241
+ "p99": 2.0,
1242
+ "p99_9": 5.0,
1243
+ "max": 25.0,
1244
+ }
1245
+
1246
+ # Initialize quaternion validator
1247
+ quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances)
1248
+
1249
+ all_passed = True
1250
+ all_validation_results = []
1251
+
1252
+ for image_path in image_paths:
1253
+ if not image_path.exists():
1254
+ LOGGER.error(f"Input image not found: {image_path}")
1255
+ all_passed = False
1256
+ continue
1257
+
1258
+ LOGGER.info(f"\n--- Validating with {image_path.name} ---")
1259
+
1260
+ # Run validation for this image and collect detailed results
1261
+ image_results = validate_with_single_image_detailed(
1262
+ mlmodel, pytorch_model, image_path, input_shape, quat_validator
1263
+ )
1264
+
1265
+ # Add image name to each result
1266
+ for result in image_results:
1267
+ result["image"] = image_path.name
1268
+ all_validation_results.append(result)
1269
+
1270
+ # Check if any results failed
1271
+ if not all(r["passed"] for r in image_results):
1272
+ all_passed = False
1273
+
1274
+ # Output combined summary table with all images and outputs
1275
+ LOGGER.info("\n" + "=" * 60)
1276
+ LOGGER.info("### Multi-Image Validation Summary")
1277
+ LOGGER.info("=" * 60 + "\n")
1278
+
1279
+ # Generate combined table
1280
+ if all_validation_results:
1281
+ table = format_validation_table(all_validation_results, "", include_image_column=True)
1282
+ LOGGER.info(table)
1283
+ LOGGER.info("")
1284
+
1285
+ return all_passed
1286
+
1287
+
1288
+ def validate_with_single_image_detailed(
1289
+ mlmodel: ct.models.MLModel,
1290
+ pytorch_model: RGBGaussianPredictor,
1291
+ image_path: Path,
1292
+ input_shape: tuple[int, int],
1293
+ quat_validator: QuaternionValidator | None = None,
1294
+ ) -> list[dict]:
1295
+ """Validate with a single image and return detailed results.
1296
+
1297
+ Args:
1298
+ mlmodel: The Core ML model to validate.
1299
+ pytorch_model: The original PyTorch model.
1300
+ image_path: Path to the input image file.
1301
+ input_shape: Expected input image shape.
1302
+ quat_validator: Optional QuaternionValidator instance.
1303
+
1304
+ Returns:
1305
+ List of validation result dictionaries.
1306
+ """
1307
+ # Load and preprocess the input image with focal length
1308
+ test_image, f_px, (orig_width, orig_height) = load_and_preprocess_image(image_path, input_shape)
1309
+
1310
+ # Compute disparity_factor as focal_length / width (matching predict.py)
1311
+ disparity_factor = f_px / orig_width
1312
+ LOGGER.info(f"Using disparity_factor = {disparity_factor:.6f} (f_px={f_px:.2f} / width={orig_width})")
1313
+
1314
+ # Run inference on both models
1315
+ pt_outputs, coreml_outputs = run_inference_pair(
1316
+ pytorch_model, mlmodel, test_image,
1317
+ disparity_factor=disparity_factor,
1318
+ log_internals=True
1319
+ )
1320
+
1321
+ # Log depth/position statistics for debugging
1322
+ pt_positions = pt_outputs[0]
1323
+ coreml_key = find_coreml_output_key("mean_vectors_3d_positions", coreml_outputs)
1324
+ coreml_positions = coreml_outputs[coreml_key]
1325
+
1326
+ # Detailed position analysis
1327
+ LOGGER.info(f"=== Depth/Position Statistics ({image_path.name}) ===")
1328
+ LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}")
1329
+ LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}")
1330
+
1331
+ # Analyze position differences
1332
+ pos_diff = np.abs(pt_positions - coreml_positions)
1333
+ LOGGER.info(f"Position difference (X,Y,Z) - max: [{pos_diff[..., 0].max():.6f}, {pos_diff[..., 1].max():.6f}, {pos_diff[..., 2].max():.6f}]")
1334
+ LOGGER.info(f"Position difference (X,Y,Z) - mean: [{pos_diff[..., 0].mean():.6f}, {pos_diff[..., 1].mean():.6f}, {pos_diff[..., 2].mean():.6f}]")
1335
+
1336
+ # Check if error is proportional to depth (would indicate global_scale issue)
1337
+ z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
1338
+ z_ratio = z_diff / np.clip(pt_positions[..., 2], 1e-6, None)
1339
+ LOGGER.info(f"Z relative error - mean: {z_ratio.mean()*100:.4f}%, max: {z_ratio.max()*100:.4f}%")
1340
+
1341
+ # Log scales for comparison
1342
+ pt_scales = pt_outputs[1]
1343
+ coreml_scales_key = find_coreml_output_key("singular_values_scales", coreml_outputs)
1344
+ coreml_scales = coreml_outputs[coreml_scales_key]
1345
+ scales_diff = np.abs(pt_scales - coreml_scales)
1346
+ scales_ratio = scales_diff / np.clip(pt_scales, 1e-6, None)
1347
+ LOGGER.info(f"Scales relative error - mean: {scales_ratio.mean()*100:.4f}%, max: {scales_ratio.max()*100:.4f}%")
1348
+
1349
+ # Tolerances for real image validation
1350
+ tolerance_config = ToleranceConfig()
1351
+ tolerances = tolerance_config.image_tolerances
1352
+
1353
+ # Use provided validator or create default with image tolerances
1354
+ if quat_validator is None:
1355
+ quat_validator = QuaternionValidator(
1356
+ angular_tolerances=tolerance_config.angular_tolerances_image
1357
+ )
1358
+
1359
+ # Compare outputs
1360
+ validation_results = compare_outputs(
1361
+ pt_outputs,
1362
+ coreml_outputs,
1363
+ tolerances,
1364
+ quat_validator,
1365
+ image_name=image_path.name
1366
+ )
1367
+
1368
+ return validation_results
1369
+
1370
+
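The disparity_factor fed to both models above is simply focal length over image width. A worked example with illustrative (hypothetical) numbers:

```python
# Hypothetical values; the real f_px comes from load_and_preprocess_image.
f_px, orig_width = 1152.0, 1536
disparity_factor = f_px / orig_width  # 0.75, passed as the "disparity_factor" input
```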
1371
+ def validate_with_single_image(
1372
+ mlmodel: ct.models.MLModel,
1373
+ pytorch_model: RGBGaussianPredictor,
1374
+ image_path: Path,
1375
+ input_shape: tuple[int, int],
1376
+ quat_validator: QuaternionValidator | None = None,
1377
+ ) -> bool:
1378
+ """Validate with a single image using the new QuaternionValidator.
1379
+
1380
+ Args:
1381
+ mlmodel: The Core ML model to validate.
1382
+ pytorch_model: The original PyTorch model.
1383
+ image_path: Path to the input image file.
1384
+ input_shape: Expected input image shape.
1385
+ quat_validator: Optional QuaternionValidator instance.
1386
+
1387
+ Returns:
1388
+ True if validation passes, False otherwise.
1389
+ """
1390
+ # Load and preprocess the input image
1391
+ test_image, _f_px, _orig_size = load_and_preprocess_image(image_path, input_shape)
1392
+ test_disparity = np.array([1.0], dtype=np.float32)
1393
+
1394
+ # Run PyTorch model
1395
+ traceable_wrapper = SharpModelTraceable(pytorch_model)
1396
+ traceable_wrapper.eval()
1397
+
1398
+ with torch.no_grad():
1399
+ pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))
1400
+
1401
+ # Run Core ML model
1402
+ test_image_np = test_image.numpy()
1403
+ coreml_inputs = {
1404
+ "image": test_image_np,
1405
+ "disparity_factor": test_disparity,
1406
+ }
1407
+ coreml_outputs = mlmodel.predict(coreml_inputs)
1408
+
1409
+ # Output configuration
1410
+ output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
1411
+
1412
+ # Tolerances for real image validation
1413
+ tolerances = {
1414
+ "mean_vectors_3d_positions": 1.2,
1415
+ "singular_values_scales": 0.01,
1416
+ "colors_rgb_linear": 0.01,
1417
+ "opacities_alpha_channel": 0.05,
1418
+ "quaternions_rotations": 5.0,
1419
+ }
1420
+
1421
+ # Use provided validator or create default
1422
+ if quat_validator is None:
1423
+ quat_validator = QuaternionValidator()
1424
+
1425
+ # Log input image statistics
1426
+ LOGGER.info(f"Image: {image_path.name}, shape: {test_image.shape}, range: [{test_image.min():.4f}, {test_image.max():.4f}]")
1427
+
1428
+ # Collect validation results
1429
+ all_passed = True
1430
+ validation_results = []
1431
+
1432
+ for i, name in enumerate(output_names):
1433
+ pt_output = pt_outputs[i].numpy()
1434
+
1435
+ # Find matching Core ML output
1436
+ coreml_key = None
1437
+ if name in coreml_outputs:
1438
+ coreml_key = name
1439
+ else:
1440
+ for key in coreml_outputs:
1441
+ base_name = name.split('_')[0]
1442
+ if base_name in key.lower():
1443
+ coreml_key = key
1444
+ break
1445
+ if coreml_key is None:
1446
+ coreml_key = list(coreml_outputs.keys())[i]
1447
+
1448
+ coreml_output = coreml_outputs[coreml_key]
1449
+ result = {"output": name, "passed": True, "failure_reason": ""}
1450
+
1451
+ if name == "quaternions_rotations":
1452
+ # Use QuaternionValidator
1453
+ quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_path.name)
1454
+
1455
+ result.update({
1456
+ "max_diff": f"{quat_result['stats']['max']:.6f}",
1457
+ "mean_diff": f"{quat_result['stats']['mean']:.6f}",
1458
+ "p99_diff": f"{quat_result['stats']['p99']:.6f}",
1459
+ "passed": quat_result["passed"],
1460
+ "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
1461
+ })
1462
+
1463
+ if not quat_result["passed"]:
1464
+ all_passed = False
1465
+ else:
1466
+ diff = np.abs(pt_output - coreml_output)
1467
+ output_tolerance = tolerances.get(name, 0.01)
1468
+ max_diff = np.max(diff)
1469
+
1470
+ result.update({
1471
+ "max_diff": f"{max_diff:.6f}",
1472
+ "mean_diff": f"{np.mean(diff):.6f}",
1473
+ "p99_diff": f"{np.percentile(diff, 99):.6f}",
1474
+ })
1475
+
1476
+ if max_diff > output_tolerance:
1477
+ result["passed"] = False
1478
+ result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
1479
+ all_passed = False
1480
+
1481
+ validation_results.append(result)
1482
+
1483
+ # Output validation results as markdown table
1484
+ LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
1485
+ table = format_validation_table(validation_results, image_path.name, include_image_column=False)
1486
+ LOGGER.info(table)
1487
+ LOGGER.info("")
1488
+
1489
+ return all_passed
1490
+
1491
+
1492
+ def main():
1493
+ """Main conversion script."""
1494
+ parser = argparse.ArgumentParser(
1495
+ description="Convert SHARP PyTorch model to Core ML format"
1496
+ )
1497
+ parser.add_argument(
1498
+ "-c", "--checkpoint",
1499
+ type=Path,
1500
+ default=None,
1501
+ help="Path to PyTorch checkpoint. Downloads default if not provided.",
1502
+ )
1503
+ parser.add_argument(
1504
+ "-o", "--output",
1505
+ type=Path,
1506
+ default=Path("sharp.mlpackage"),
1507
+ help="Output path for Core ML model (default: sharp.mlpackage)",
1508
+ )
1509
+ parser.add_argument(
1510
+ "--height",
1511
+ type=int,
1512
+ default=1536,
1513
+ help="Input image height (default: 1536)",
1514
+ )
1515
+ parser.add_argument(
1516
+ "--width",
1517
+ type=int,
1518
+ default=1536,
1519
+ help="Input image width (default: 1536)",
1520
+ )
1521
+ parser.add_argument(
1522
+ "--precision",
1523
+ choices=["float16", "float32"],
1524
+ default="float32",
1525
+ help="Compute precision (default: float32)",
1526
+ )
1527
+ parser.add_argument(
1528
+ "--validate",
1529
+ action="store_true",
1530
+ help="Validate Core ML model against PyTorch",
1531
+ )
1532
+ parser.add_argument(
1533
+ "-v", "--verbose",
1534
+ action="store_true",
1535
+ help="Enable verbose logging",
1536
+ )
1537
+ parser.add_argument(
1538
+ "--input-image",
1539
+ type=Path,
1540
+ default=None,
1541
+ action="append",
1542
+ help="Path to input image for validation (can be specified multiple times, requires --validate)",
1543
+ )
1544
+ parser.add_argument(
1545
+ "--tolerance-mean",
1546
+ type=float,
1547
+ default=None,
1548
+ help="Custom mean angular tolerance in degrees (default: 0.01 for random, 0.1 for images)",
1549
+ )
1550
+ parser.add_argument(
1551
+ "--tolerance-p99",
1552
+ type=float,
1553
+ default=None,
1554
+ help="Custom P99 angular tolerance in degrees (default: 0.5 for random, 1.0 for images)",
1555
+ )
1556
+ parser.add_argument(
1557
+ "--tolerance-max",
1558
+ type=float,
1559
+ default=None,
1560
+ help="Custom max angular tolerance in degrees (default: 15.0)",
1561
+ )
1562
+
1563
+ args = parser.parse_args()
1564
+
1565
+ # Configure logging
1566
+ logging.basicConfig(
1567
+ level=logging.DEBUG if args.verbose else logging.INFO,
1568
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
1569
+ )
1570
+
1571
+ # Load PyTorch model
1572
+ LOGGER.info("Loading SHARP model...")
1573
+ predictor = load_sharp_model(args.checkpoint)
1574
+
1575
+ # Setup conversion parameters
1576
+ input_shape = (args.height, args.width)
1577
+ precision = ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32
1578
+
1579
+ # Convert to Core ML
1580
+ LOGGER.info("Converting using direct tracing...")
1581
+ mlmodel = convert_to_coreml(
1582
+ predictor,
1583
+ args.output,
1584
+ input_shape=input_shape,
1585
+ compute_precision=precision,
1586
+ )
1587
+
1588
+ LOGGER.info(f"Core ML model saved to {args.output}")
1589
+
1590
+ # Validate if requested
1591
+ if args.validate:
1592
+ if args.input_image:
1593
+ # Validate with one or more real input images
1594
+ validation_passed = validate_with_image_set(mlmodel, predictor, args.input_image, input_shape)
1595
+ else:
1596
+ # Validate with random input (default behavior)
1597
+ # Build custom angular tolerances from CLI args
1598
+ angular_tolerances = None
1599
+ if args.tolerance_mean is not None or args.tolerance_p99 is not None or args.tolerance_max is not None:
1600
+ angular_tolerances = {
1601
+ "mean": args.tolerance_mean if args.tolerance_mean else 0.01,
1602
+ "p99": args.tolerance_p99 if args.tolerance_p99 else 0.5,
1603
+ "p99_9": 2.0,
1604
+ "max": args.tolerance_max if args.tolerance_max else 15.0,
1605
+ }
1606
+ validation_passed = validate_coreml_model(mlmodel, predictor, input_shape, angular_tolerances=angular_tolerances)
1607
+
1608
+ if validation_passed:
1609
+ LOGGER.info("✓ Validation passed!")
1610
+ else:
1611
+ LOGGER.error("✗ Validation failed!")
1612
+ return 1
1613
+
1614
+ LOGGER.info("Conversion complete!")
1615
+ return 0
1616
+
1617
+
1618
+ if __name__ == "__main__":
1619
+ exit(main())
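Once converted, the resulting sharp.mlpackage can be exercised directly from Python. A minimal sketch, assuming a macOS environment with coremltools installed and the package saved under the default name:

```python
import coremltools as ct
import numpy as np

model = ct.models.MLModel("sharp.mlpackage")
spec = model.get_spec()
print([i.name for i in spec.description.input])   # expect: image, disparity_factor
print([o.name for o in spec.description.output])

# Dummy forward pass at the default 1536x1536 input size.
image = np.random.rand(1, 3, 1536, 1536).astype(np.float32)
outputs = model.predict({
    "image": image,
    "disparity_factor": np.array([1.0], dtype=np.float32),
})
print({k: v.shape for k, v in outputs.items()})
```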
sharp.mlpackage/Data/com.apple.CoreML/model.mlmodel ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e2b156a2a72ad6f86da86b9100b13007b0d343bbd654fba8d65bee66553f2f1
3
+ size 938769
sharp.mlpackage/Data/com.apple.CoreML/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b2b162a556856468c4602aa50676ccbf638b7eb714e807e403d6ac0fa99bce
3
+ size 2672576384
sharp.mlpackage/Manifest.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "fileFormatVersion": "1.0.0",
3
+ "itemInfoEntries": {
4
+ "655381FB-8159-4BD7-A64E-7B14F30B787E": {
5
+ "author": "com.apple.CoreML",
6
+ "description": "CoreML Model Weights",
7
+ "name": "weights",
8
+ "path": "com.apple.CoreML/weights"
9
+ },
10
+ "A0921877-4847-4CCE-937D-414310330106": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Specification",
13
+ "name": "model.mlmodel",
14
+ "path": "com.apple.CoreML/model.mlmodel"
15
+ }
16
+ },
17
+ "rootModelIdentifier": "A0921877-4847-4CCE-937D-414310330106"
18
+ }
sharp.swift ADDED
@@ -0,0 +1,765 @@
1
+ //
2
+ // SHARPModelRunner.swift
3
+ // SHARP Model Inference and PLY Export
4
+ //
5
+ // Loads a SHARP Core ML model, runs inference on an image,
6
+ // and saves the 3D Gaussian splat output as a PLY file.
7
+ //
8
+ // Usage:
9
+ // swiftc -O -o sharp_runner sharp.swift -framework CoreML -framework CoreImage -framework AppKit
10
+ // ./sharp_runner sharp.mlpackage test.png output.ply -d 0.5
11
+
12
+ import Foundation
13
+ import CoreML
14
+ import CoreImage
15
+ import AppKit // For NSImage on macOS; use UIKit for iOS
16
+
17
+ // MARK: - Gaussians3D Structure
18
+
19
+ /// Represents the output of the SHARP model - a collection of 3D Gaussians
20
+ struct Gaussians3D {
21
+ let meanVectors: MLMultiArray // Shape: (1, N, 3) - 3D positions
22
+ let singularValues: MLMultiArray // Shape: (1, N, 3) - scales
23
+ let quaternions: MLMultiArray // Shape: (1, N, 4) - rotations
24
+ let colors: MLMultiArray // Shape: (1, N, 3) - RGB colors (linear)
25
+ let opacities: MLMultiArray // Shape: (1, N) - opacity values
26
+
27
+ var count: Int {
28
+ return meanVectors.shape[1].intValue
29
+ }
30
+
31
+ /// Compute importance scores for each Gaussian.
32
+ /// Higher scores = more important (larger and more opaque).
33
+ func computeImportanceScores() -> [Float] {
34
+ let n = count
35
+ var scores = [Float](repeating: 0, count: n)
36
+
37
+ let scalePtr = singularValues.dataPointer.assumingMemoryBound(to: Float.self)
38
+ let opacityPtr = opacities.dataPointer.assumingMemoryBound(to: Float.self)
39
+
40
+ for i in 0..<n {
41
+ // singularValues are already linear-space scales (i.e. exp(log_scale)).
42
+ // The Python reference computes exp(log_s0 + log_s1 + log_s2), which is
43
+ // equivalent to the plain product s0 * s1 * s2 used below.
44
+ let s0 = scalePtr[i * 3 + 0]
45
+ let s1 = scalePtr[i * 3 + 1]
46
+ let s2 = scalePtr[i * 3 + 2]
47
+
48
+ // Product of scales (equivalent to exp(log_s0 + log_s1 + log_s2))
49
+ let scaleProduct = s0 * s1 * s2
50
+
51
+ // Opacity is already in [0, 1] range (after sigmoid in model)
52
+ let opacity = opacityPtr[i]
53
+
54
+ scores[i] = scaleProduct * opacity
55
+ }
56
+
57
+ return scores
58
+ }
59
+
60
+ /// Decimate the Gaussians by keeping only a fraction based on importance.
61
+ /// Returns indices of Gaussians to keep, sorted for spatial coherence.
62
+ func decimationIndices(keepRatio: Float) -> [Int] {
63
+ let n = count
64
+ let keepCount = max(1, Int(Float(n) * keepRatio))
65
+
66
+ // Compute importance scores
67
+ let scores = computeImportanceScores()
68
+
69
+ // Create array of (index, score) pairs and sort by score descending
70
+ var indexedScores = scores.enumerated().map { ($0.offset, $0.element) }
71
+ indexedScores.sort { $0.1 > $1.1 }
72
+
73
+ // Get top keepCount indices
74
+ var keepIndices = indexedScores.prefix(keepCount).map { $0.0 }
75
+
76
+ // Sort indices to maintain spatial coherence
77
+ keepIndices.sort()
78
+
79
+ return keepIndices
80
+ }
81
+ }
82
+
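The importance-based decimation above keeps the largest, most opaque splats. The same logic in a few lines of NumPy (a reference sketch, not repo code): score = s0·s1·s2·opacity, keep the top keep_ratio fraction, then re-sort the surviving indices for spatial coherence.

```python
import numpy as np

def decimation_indices(scales: np.ndarray, opacities: np.ndarray, keep_ratio: float) -> np.ndarray:
    """scales: (N, 3) linear scales; opacities: (N,) in [0, 1]."""
    scores = scales.prod(axis=-1) * opacities
    keep_count = max(1, int(len(scores) * keep_ratio))
    keep = np.argsort(-scores)[:keep_count]  # highest scores first
    return np.sort(keep)                     # restore index order

scales = np.random.rand(1000, 3).astype(np.float32)
opacities = np.random.rand(1000).astype(np.float32)
print(decimation_indices(scales, opacities, 0.5).shape)  # (500,)
```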
83
+ // MARK: - Color Space Utilities
84
+
85
+ /// Convert linear RGB to sRGB color space
86
+ func linearRGBToSRGB(_ linear: Float) -> Float {
87
+ if linear <= 0.0031308 {
88
+ return linear * 12.92
89
+ } else {
90
+ return 1.055 * pow(linear, 1.0 / 2.4) - 0.055
91
+ }
92
+ }
93
+
94
+ /// Convert RGB to degree-0 spherical harmonics
95
+ func rgbToSphericalHarmonics(_ rgb: Float) -> Float {
96
+ let coeffDegree0 = sqrt(1.0 / (4.0 * Float.pi))
97
+ return (rgb - 0.5) / coeffDegree0
98
+ }
99
+
100
+ /// Inverse sigmoid function
101
+ func inverseSigmoid(_ x: Float) -> Float {
102
+ let clamped = min(max(x, 1e-6), 1.0 - 1e-6)
103
+ return log(clamped / (1.0 - clamped))
104
+ }
105
+
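For cross-checking the Swift utilities against the Python side, here are sketch equivalents of the three conversions: linear RGB to sRGB, sRGB to the degree-0 spherical-harmonic coefficient with C0 = sqrt(1/(4π)) ≈ 0.2821, and the inverse sigmoid used to turn opacities back into logits.

```python
import math

def linear_to_srgb(c: float) -> float:
    return c * 12.92 if c <= 0.0031308 else 1.055 * c ** (1 / 2.4) - 0.055

def rgb_to_sh0(c: float) -> float:
    c0 = math.sqrt(1.0 / (4.0 * math.pi))  # ~0.28209
    return (c - 0.5) / c0

def inverse_sigmoid(x: float) -> float:
    x = min(max(x, 1e-6), 1.0 - 1e-6)
    return math.log(x / (1.0 - x))

print(linear_to_srgb(0.5), rgb_to_sh0(0.5), inverse_sigmoid(0.5))  # ~0.7354, 0.0, 0.0
```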
106
+ // MARK: - SHARP Model Wrapper
107
+
108
+ class SHARPModelRunner {
109
+ private let model: MLModel
110
+ private let inputHeight: Int
111
+ private let inputWidth: Int
112
+
113
+ init(modelPath: URL, inputHeight: Int = 1536, inputWidth: Int = 1536) throws {
114
+ let config = MLModelConfiguration()
115
+ config.computeUnits = .all
116
+
117
+ // Compile the model if needed
118
+ let compiledModelURL = try SHARPModelRunner.compileModelIfNeeded(at: modelPath)
119
+
120
+ self.model = try MLModel(contentsOf: compiledModelURL, configuration: config)
121
+ self.inputHeight = inputHeight
122
+ self.inputWidth = inputWidth
123
+
124
+ // Print model description for debugging
125
+ print("Model inputs: \(model.modelDescription.inputDescriptionsByName.keys.joined(separator: ", "))")
126
+ print("Model outputs: \(model.modelDescription.outputDescriptionsByName.keys.joined(separator: ", "))")
127
+ }
128
+
129
+ /// Compile the model if it's not already compiled
130
+ private static func compileModelIfNeeded(at modelPath: URL) throws -> URL {
131
+ let fileManager = FileManager.default
132
+ let pathExtension = modelPath.pathExtension.lowercased()
133
+
134
+ // If already compiled (.mlmodelc), return as-is
135
+ if pathExtension == "mlmodelc" {
136
+ print("Model is already compiled.")
137
+ return modelPath
138
+ }
139
+
140
+ // Check if it's an .mlpackage or .mlmodel that needs compilation
141
+ guard pathExtension == "mlpackage" || pathExtension == "mlmodel" else {
142
+ throw NSError(domain: "SHARPModelRunner", code: 10,
143
+ userInfo: [NSLocalizedDescriptionKey: "Unsupported model format: \(pathExtension).Use .mlpackage, .mlmodel, or .mlmodelc"])
144
+ }
145
+
146
+ // Create a cache directory for compiled models
147
+ let cacheDir = fileManager.temporaryDirectory.appendingPathComponent("SHARPModelCache")
148
+ try? fileManager.createDirectory(at: cacheDir, withIntermediateDirectories: true)
149
+
150
+ // Generate a unique name for the compiled model based on the source path
151
+ let modelName = modelPath.deletingPathExtension().lastPathComponent
152
+ let compiledPath = cacheDir.appendingPathComponent("\(modelName).mlmodelc")
153
+
154
+ // Check if we have a cached compiled version
155
+ if fileManager.fileExists(atPath: compiledPath.path) {
156
+ // Verify the cached version is newer than the source
157
+ let sourceAttrs = try fileManager.attributesOfItem(atPath: modelPath.path)
158
+ let cachedAttrs = try fileManager.attributesOfItem(atPath: compiledPath.path)
159
+
160
+ if let sourceDate = sourceAttrs[.modificationDate] as? Date,
161
+ let cachedDate = cachedAttrs[.modificationDate] as? Date,
162
+ cachedDate >= sourceDate {
163
+ print("Using cached compiled model at \(compiledPath.path)")
164
+ return compiledPath
165
+ } else {
166
+ // Source is newer, remove old cached version
167
+ try? fileManager.removeItem(at: compiledPath)
168
+ }
169
+ }
170
+
171
+ // Compile the model
172
+ print("Compiling model (this may take a moment)...")
173
+ let startTime = CFAbsoluteTimeGetCurrent()
174
+
175
+ let temporaryCompiledURL = try MLModel.compileModel(at: modelPath)
176
+
177
+ let compileTime = CFAbsoluteTimeGetCurrent() - startTime
178
+ print("✓ Model compiled in \(String(format: "%.1f", compileTime))s")
179
+
180
+ // Move to our cache directory
181
+ try? fileManager.removeItem(at: compiledPath)
182
+ try fileManager.moveItem(at: temporaryCompiledURL, to: compiledPath)
183
+
184
+ print("Compiled model cached at \(compiledPath.path)")
185
+ return compiledPath
186
+ }
187
+
188
+ /// Load and preprocess an image for model input
189
+ func preprocessImage(at imagePath: URL) throws -> MLMultiArray {
190
+ guard let nsImage = NSImage(contentsOf: imagePath) else {
191
+ throw NSError(domain: "SHARPModelRunner", code: 1,
192
+ userInfo: [NSLocalizedDescriptionKey: "Failed to load image from \(imagePath.path)"])
193
+ }
194
+
195
+ guard let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
196
+ throw NSError(domain: "SHARPModelRunner", code: 2,
197
+ userInfo: [NSLocalizedDescriptionKey: "Failed to convert to CGImage"])
198
+ }
199
+
200
+ // Create CIImage and resize
201
+ let ciImage = CIImage(cgImage: cgImage)
202
+ let context = CIContext()
203
+
204
+ // Scale to target size
205
+ let scaleX = CGFloat(inputWidth) / ciImage.extent.width
206
+ let scaleY = CGFloat(inputHeight) / ciImage.extent.height
207
+ let scaledImage = ciImage.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY))
208
+
209
+ // Render to bitmap
210
+ guard let resizedCGImage = context.createCGImage(scaledImage, from: CGRect(x: 0, y: 0,
211
+ width: inputWidth,
212
+ height: inputHeight)) else {
213
+ throw NSError(domain: "SHARPModelRunner", code: 3,
214
+ userInfo: [NSLocalizedDescriptionKey: "Failed to resize image"])
215
+ }
216
+
217
+ // Convert to MLMultiArray (1, 3, H, W) normalized to [0, 1]
218
+ let imageArray = try MLMultiArray(shape: [1, 3, NSNumber(value: inputHeight), NSNumber(value: inputWidth)],
219
+ dataType: .float32)
220
+
221
+ let width = resizedCGImage.width
222
+ let height = resizedCGImage.height
223
+ let bytesPerPixel = 4
224
+ let bytesPerRow = bytesPerPixel * width
225
+ var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow)
226
+
227
+ let colorSpace = CGColorSpaceCreateDeviceRGB()
228
+ guard let cgContext = CGContext(data: &pixelData,
229
+ width: width,
230
+ height: height,
231
+ bitsPerComponent: 8,
232
+ bytesPerRow: bytesPerRow,
233
+ space: colorSpace,
234
+ bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue) else {
235
+ throw NSError(domain: "SHARPModelRunner", code: 4,
236
+ userInfo: [NSLocalizedDescriptionKey: "Failed to create bitmap context"])
237
+ }
238
+
239
+ cgContext.draw(resizedCGImage, in: CGRect(x: 0, y: 0, width: width, height: height))
240
+
241
+ // Copy pixel data to MLMultiArray in CHW format
242
+ // Use pointer access for better performance
243
+ let ptr = imageArray.dataPointer.assumingMemoryBound(to: Float.self)
244
+ let channelStride = inputHeight * inputWidth
245
+
246
+ for y in 0..<height {
247
+ for x in 0..<width {
248
+ let pixelIndex = y * bytesPerRow + x * bytesPerPixel
249
+ let r = Float(pixelData[pixelIndex]) / 255.0
250
+ let g = Float(pixelData[pixelIndex + 1]) / 255.0
251
+ let b = Float(pixelData[pixelIndex + 2]) / 255.0
252
+
253
+ let spatialIndex = y * inputWidth + x
254
+ ptr[0 * channelStride + spatialIndex] = r
255
+ ptr[1 * channelStride + spatialIndex] = g
256
+ ptr[2 * channelStride + spatialIndex] = b
257
+ }
258
+ }
259
+
260
+ return imageArray
261
+ }
262
+
263
+ /// Run inference on the model
264
+ func predict(image: MLMultiArray, focalLengthPx: Float) throws -> Gaussians3D {
265
+ // Calculate disparity factor: focal_length / image_width
266
+ let disparityFactor = focalLengthPx / Float(inputWidth)
267
+
268
+ // Create disparity factor input
269
+ let disparityArray = try MLMultiArray(shape: [1], dataType: .float32)
270
+ disparityArray[0] = NSNumber(value: disparityFactor)
271
+
272
+ // Create feature provider
273
+ let inputFeatures = try MLDictionaryFeatureProvider(dictionary: [
274
+ "image": MLFeatureValue(multiArray: image),
275
+ "disparity_factor": MLFeatureValue(multiArray: disparityArray)
276
+ ])
277
+
278
+ // Run prediction
279
+ let output = try model.prediction(from: inputFeatures)
280
+
281
+ // Try to find outputs by checking available names
282
+ let outputNames = Array(model.modelDescription.outputDescriptionsByName.keys)
283
+
284
+ // Helper function to find output by partial name match
285
+ func findOutput(containing keywords: [String]) -> MLMultiArray? {
286
+ for name in outputNames {
287
+ let lowercaseName = name.lowercased()
288
+ for keyword in keywords {
289
+ if lowercaseName.contains(keyword.lowercased()) {
290
+ return output.featureValue(for: name)?.multiArrayValue
291
+ }
292
+ }
293
+ }
294
+ return nil
295
+ }
296
+
297
+ // Try to match outputs - first try exact names, then partial matches
298
+ let meanVectors = output.featureValue(for: "mean_vectors_3d_positions")?.multiArrayValue
299
+ ?? findOutput(containing: ["mean", "position", "xyz"])
300
+
301
+ let singularValues = output.featureValue(for: "singular_values_scales")?.multiArrayValue
302
+ ?? findOutput(containing: ["singular", "scale"])
303
+
304
+ let quaternions = output.featureValue(for: "quaternions_rotations")?.multiArrayValue
305
+ ?? findOutput(containing: ["quaternion", "rotation", "rot"])
306
+
307
+ let colors = output.featureValue(for: "colors_rgb_linear")?.multiArrayValue
308
+ ?? findOutput(containing: ["color", "rgb"])
309
+
310
+ let opacities = output.featureValue(for: "opacities_alpha_channel")?.multiArrayValue
311
+ ?? findOutput(containing: ["opacity", "alpha"])
312
+
313
+ // If we still couldn't find outputs, try by index order
314
+ if meanVectors == nil || singularValues == nil || quaternions == nil || colors == nil || opacities == nil {
315
+ print("Warning: Could not match all outputs by name.Available outputs: \(outputNames)")
316
+
317
+ // Try to get outputs by index if we have exactly 5
318
+ if outputNames.count >= 5 {
319
+ let sortedNames = outputNames.sorted()
320
+ guard let mv = output.featureValue(for: sortedNames[0])?.multiArrayValue,
321
+ let sv = output.featureValue(for: sortedNames[1])?.multiArrayValue,
322
+ let q = output.featureValue(for: sortedNames[2])?.multiArrayValue,
323
+ let c = output.featureValue(for: sortedNames[3])?.multiArrayValue,
324
+ let o = output.featureValue(for: sortedNames[4])?.multiArrayValue else {
325
+ throw NSError(domain: "SHARPModelRunner", code: 5,
326
+ userInfo: [NSLocalizedDescriptionKey: "Failed to extract model outputs. Available: \(outputNames)"])
327
+ }
328
+
329
+ print("Using outputs by sorted order: \(sortedNames)")
330
+ return Gaussians3D(
331
+ meanVectors: mv,
332
+ singularValues: sv,
333
+ quaternions: q,
334
+ colors: c,
335
+ opacities: o
336
+ )
337
+ }
338
+
339
+ throw NSError(domain: "SHARPModelRunner", code: 5,
340
+ userInfo: [NSLocalizedDescriptionKey: "Failed to extract model outputs.Available: \(outputNames)"])
341
+ }
342
+
343
+ return Gaussians3D(
344
+ meanVectors: meanVectors!,
345
+ singularValues: singularValues!,
346
+ quaternions: quaternions!,
347
+ colors: colors!,
348
+ opacities: opacities!
349
+ )
350
+ }
351
+
352
+ /// Save Gaussians to PLY file (matching Python save_ply format exactly)
353
+ /// - Parameters:
354
+ /// - gaussians: The Gaussians to save
355
+ /// - focalLengthPx: Focal length in pixels
356
+ /// - imageShape: Image dimensions (height, width)
357
+ /// - outputPath: Output file path
358
+ /// - decimation: Optional decimation ratio (0.0-1.0). 1.0 = keep all, 0.5 = keep 50%
359
+ func savePLY(gaussians: Gaussians3D,
360
+ focalLengthPx: Float,
361
+ imageShape: (height: Int, width: Int),
362
+ to outputPath: URL,
363
+ decimation: Float = 1.0) throws {
364
+
365
+ let imageHeight = imageShape.height
366
+ let imageWidth = imageShape.width
367
+
368
+ // Determine which indices to keep based on decimation
369
+ let keepIndices: [Int]
370
+ let originalCount = gaussians.count
371
+
372
+ if decimation < 1.0 {
373
+ keepIndices = gaussians.decimationIndices(keepRatio: decimation)
374
+ print("Decimating: keeping \(keepIndices.count) of \(originalCount) Gaussians (\(String(format: "%.1f", decimation * 100))%)")
375
+ } else {
376
+ keepIndices = Array(0..<originalCount)
377
+ }
378
+
379
+ let numGaussians = keepIndices.count
380
+
381
+ var fileContent = Data()
382
+
383
+ // Helper to append string
384
+ func appendString(_ str: String) {
385
+ fileContent.append(str.data(using: .ascii)!)
386
+ }
387
+
388
+ // Helper to append float32 in little-endian
389
+ func appendFloat32(_ value: Float) {
390
+ var v = value
391
+ fileContent.append(Data(bytes: &v, count: 4))
392
+ }
393
+
394
+ // Helper to append int32 in little-endian
395
+ func appendInt32(_ value: Int32) {
396
+ var v = value
397
+ fileContent.append(Data(bytes: &v, count: 4))
398
+ }
399
+
400
+ // Helper to append uint32 in little-endian
401
+ func appendUInt32(_ value: UInt32) {
402
+ var v = value
403
+ fileContent.append(Data(bytes: &v, count: 4))
404
+ }
405
+
406
+ // Helper to append uint8
407
+ func appendUInt8(_ value: UInt8) {
408
+ var v = value
409
+ fileContent.append(Data(bytes: &v, count: 1))
410
+ }
411
+
412
+ // ===== PLY Header =====
413
+ appendString("ply\n")
414
+ appendString("format binary_little_endian 1.0\n")
415
+
416
+ // Vertex element
417
+ appendString("element vertex \(numGaussians)\n")
418
+ appendString("property float x\n")
419
+ appendString("property float y\n")
420
+ appendString("property float z\n")
421
+ appendString("property float f_dc_0\n")
422
+ appendString("property float f_dc_1\n")
423
+ appendString("property float f_dc_2\n")
424
+ appendString("property float opacity\n")
425
+ appendString("property float scale_0\n")
426
+ appendString("property float scale_1\n")
427
+ appendString("property float scale_2\n")
428
+ appendString("property float rot_0\n")
429
+ appendString("property float rot_1\n")
430
+ appendString("property float rot_2\n")
431
+ appendString("property float rot_3\n")
432
+
433
+ // Extrinsic element (16 floats for 4x4 identity matrix)
434
+ appendString("element extrinsic 16\n")
435
+ appendString("property float extrinsic\n")
436
+
437
+ // Intrinsic element (9 floats for 3x3 matrix)
438
+ appendString("element intrinsic 9\n")
439
+ appendString("property float intrinsic\n")
440
+
441
+ // Image size element
442
+ appendString("element image_size 2\n")
443
+ appendString("property uint image_size\n")
444
+
445
+ // Frame element
446
+ appendString("element frame 2\n")
447
+ appendString("property int frame\n")
448
+
449
+ // Disparity element
450
+ appendString("element disparity 2\n")
451
+ appendString("property float disparity\n")
452
+
453
+ // Color space element
454
+ appendString("element color_space 1\n")
455
+ appendString("property uchar color_space\n")
456
+
457
+ // Version element
458
+ appendString("element version 3\n")
459
+ appendString("property uchar version\n")
460
+
461
+ appendString("end_header\n")
462
+
463
+ // ===== Vertex Data =====
464
+ // Compute disparity quantiles for later
465
+ var disparities: [Float] = []
466
+
467
+ // Get pointers for faster access
468
+ let meanPtr = gaussians.meanVectors.dataPointer.assumingMemoryBound(to: Float.self)
469
+ let scalePtr = gaussians.singularValues.dataPointer.assumingMemoryBound(to: Float.self)
470
+ let quatPtr = gaussians.quaternions.dataPointer.assumingMemoryBound(to: Float.self)
471
+ let colorPtr = gaussians.colors.dataPointer.assumingMemoryBound(to: Float.self)
472
+ let opacityPtr = gaussians.opacities.dataPointer.assumingMemoryBound(to: Float.self)
473
+
474
+ for i in keepIndices {
475
+ // Position (x, y, z)
476
+ let x = meanPtr[i * 3 + 0]
477
+ let y = meanPtr[i * 3 + 1]
478
+ let z = meanPtr[i * 3 + 2]
479
+ appendFloat32(x)
480
+ appendFloat32(y)
481
+ appendFloat32(z)
482
+
483
+ // Compute disparity for quantiles
484
+ if z > 1e-6 {
485
+ disparities.append(1.0 / z)
486
+ }
487
+
488
+ // Colors: Convert linearRGB -> sRGB -> spherical harmonics
489
+ // Model outputs linearRGB colors for proper alpha blending
490
+ // We convert to sRGB for compatibility with public renderers
491
+ let colorR = colorPtr[i * 3 + 0]
492
+ let colorG = colorPtr[i * 3 + 1]
493
+ let colorB = colorPtr[i * 3 + 2]
494
+
495
+ let srgbR = linearRGBToSRGB(colorR)
496
+ let srgbG = linearRGBToSRGB(colorG)
497
+ let srgbB = linearRGBToSRGB(colorB)
498
+
499
+ let sh0 = rgbToSphericalHarmonics(srgbR)
500
+ let sh1 = rgbToSphericalHarmonics(srgbG)
501
+ let sh2 = rgbToSphericalHarmonics(srgbB)
502
+
503
+ appendFloat32(sh0)
504
+ appendFloat32(sh1)
505
+ appendFloat32(sh2)
506
+
507
+ // Opacity: Convert to logits using inverse sigmoid
508
+ let opacity = opacityPtr[i]
509
+ let opacityLogit = inverseSigmoid(opacity)
510
+ appendFloat32(opacityLogit)
511
+
512
+ // Scales: Convert to log scale
513
+ let scale0 = scalePtr[i * 3 + 0]
514
+ let scale1 = scalePtr[i * 3 + 1]
515
+ let scale2 = scalePtr[i * 3 + 2]
516
+
517
+ appendFloat32(log(max(scale0, 1e-10)))
518
+ appendFloat32(log(max(scale1, 1e-10)))
519
+ appendFloat32(log(max(scale2, 1e-10)))
520
+
521
+ // Quaternions (w, x, y, z)
522
+ let q0 = quatPtr[i * 4 + 0]
523
+ let q1 = quatPtr[i * 4 + 1]
524
+ let q2 = quatPtr[i * 4 + 2]
525
+ let q3 = quatPtr[i * 4 + 3]
526
+
527
+ appendFloat32(q0)
528
+ appendFloat32(q1)
529
+ appendFloat32(q2)
530
+ appendFloat32(q3)
531
+ }
532
+
533
+ // ===== Extrinsic Data (4x4 identity matrix) =====
534
+ let identity: [Float] = [
535
+ 1, 0, 0, 0,
536
+ 0, 1, 0, 0,
537
+ 0, 0, 1, 0,
538
+ 0, 0, 0, 1
539
+ ]
540
+ for val in identity {
541
+ appendFloat32(val)
542
+ }
543
+
544
+ // ===== Intrinsic Data (3x3 matrix) =====
545
+ let intrinsic: [Float] = [
546
+ focalLengthPx, 0, Float(imageWidth) * 0.5,
547
+ 0, focalLengthPx, Float(imageHeight) * 0.5,
548
+ 0, 0, 1
549
+ ]
550
+ for val in intrinsic {
551
+ appendFloat32(val)
552
+ }
553
+
554
+ // ===== Image Size Data =====
555
+ appendUInt32(UInt32(imageWidth))
556
+ appendUInt32(UInt32(imageHeight))
557
+
558
+ // ===== Frame Data =====
559
+ appendInt32(1) // Number of frames
560
+ appendInt32(Int32(numGaussians)) // Particles per frame
561
+
562
+ // ===== Disparity Data (quantiles) =====
563
+ disparities.sort()
564
+ let q10Index = Int(Float(disparities.count) * 0.1)
565
+ let q90Index = Int(Float(disparities.count) * 0.9)
566
+ let disparity10 = disparities.isEmpty ? 0.0 : disparities[min(q10Index, disparities.count - 1)]
567
+ let disparity90 = disparities.isEmpty ? 1.0 : disparities[min(q90Index, disparities.count - 1)]
568
+ appendFloat32(disparity10)
569
+ appendFloat32(disparity90)
570
+
571
+ // ===== Color Space Data (sRGB = 1) =====
572
+ appendUInt8(1)
573
+
574
+ // ===== Version Data =====
575
+ appendUInt8(1) // Major
576
+ appendUInt8(5) // Minor
577
+ appendUInt8(0) // Patch
578
+
579
+ // Write to file
580
+ try fileContent.write(to: outputPath)
581
+
582
+ print("✓ Saved PLY with \(numGaussians) Gaussians to \(outputPath.path)")
583
+ }
584
+ }
585
+
586
+ // MARK: - Command Line Argument Parsing
587
+
588
+ struct CommandLineArgs {
589
+ let modelPath: URL
590
+ let imagePath: URL
591
+ let outputPath: URL
592
+ let focalLength: Float
593
+ let decimation: Float
594
+
595
+ static func parse() -> CommandLineArgs? {
596
+ let args = CommandLine.arguments
597
+
598
+ var modelPath: URL?
599
+ var imagePath: URL?
600
+ var outputPath: URL?
601
+ var focalLength: Float = 1536.0
602
+ var decimation: Float = 1.0
603
+
604
+ var i = 1
605
+ while i < args.count {
606
+ let arg = args[i]
607
+
608
+ switch arg {
609
+ case "-m", "--model":
610
+ i += 1
611
+ if i < args.count {
612
+ modelPath = URL(fileURLWithPath: args[i])
613
+ }
614
+
615
+ case "-i", "--input":
616
+ i += 1
617
+ if i < args.count {
618
+ imagePath = URL(fileURLWithPath: args[i])
619
+ }
620
+
621
+ case "-o", "--output":
622
+ i += 1
623
+ if i < args.count {
624
+ outputPath = URL(fileURLWithPath: args[i])
625
+ }
626
+
627
+ case "-f", "--focal-length":
628
+ i += 1
629
+ if i < args.count {
630
+ focalLength = Float(args[i]) ?? 1536.0
631
+ }
632
+
633
+ case "-d", "--decimation":
634
+ i += 1
635
+ if i < args.count {
636
+ if let value = Float(args[i]) {
637
+ // Accept both percentage (0-100) and ratio (0-1)
638
+ if value > 1.0 {
639
+ decimation = value / 100.0
640
+ } else {
641
+ decimation = value
642
+ }
643
+ decimation = max(0.01, min(1.0, decimation))
644
+ }
645
+ }
646
+
647
+ case "-h", "--help":
648
+ printUsage()
649
+ return nil
650
+
651
+ default:
652
+ // Handle positional arguments for backward compatibility
653
+ if modelPath == nil {
654
+ modelPath = URL(fileURLWithPath: arg)
655
+ } else if imagePath == nil {
656
+ imagePath = URL(fileURLWithPath: arg)
657
+ } else if outputPath == nil {
658
+ outputPath = URL(fileURLWithPath: arg)
659
+ } else if focalLength == 1536.0 {
660
+ focalLength = Float(arg) ?? 1536.0
661
+ }
662
+ }
663
+
664
+ i += 1
665
+ }
666
+
667
+ guard let model = modelPath, let image = imagePath, let output = outputPath else {
668
+ printUsage()
669
+ return nil
670
+ }
671
+
672
+ return CommandLineArgs(
673
+ modelPath: model,
674
+ imagePath: image,
675
+ outputPath: output,
676
+ focalLength: focalLength,
677
+ decimation: decimation
678
+ )
679
+ }
680
+
681
+ static func printUsage() {
682
+ let execName = CommandLine.arguments[0].components(separatedBy: "/").last ?? "sharp_runner"
683
+ print("""
684
+ Usage: \(execName) [OPTIONS] <model> <input_image> <output.ply>
685
+
686
+ SHARP Model Inference - Generate 3D Gaussian Splats from a single image
687
+
688
+ Arguments:
689
+ model Path to the SHARP Core ML model (.mlpackage, .mlmodel, or .mlmodelc)
690
+ input_image Path to input image (PNG, JPEG, etc.)
691
+ output.ply Path for output PLY file
692
+
693
+ Options:
694
+ -m, --model PATH Path to Core ML model
695
+ -i, --input PATH Path to input image
696
+ -o, --output PATH Path for output PLY file
697
+ -f, --focal-length FLOAT Focal length in pixels (default: 1536)
698
+ -d, --decimation FLOAT Decimation ratio 0.0-1.0 or percentage 1-100 (default: 1.0 = keep all)
699
+ Example: 0.5 or 50 keeps 50% of Gaussians
700
+ -h, --help Show this help message
701
+
702
+ Examples:
703
+ # Basic usage
704
+ \(execName) sharp.mlpackage photo.jpg output.ply
705
+
706
+ # With focal length
707
+ \(execName) sharp.mlpackage photo.jpg output.ply 768
708
+
709
+ # With decimation (keep 50% of points)
710
+ \(execName) -m sharp.mlpackage -i photo.jpg -o output.ply -d 0.5
711
+
712
+ # With decimation as percentage
713
+ \(execName) -m sharp.mlpackage -i photo.jpg -o output.ply -d 25
714
+
715
+ The model will be automatically compiled on first use and cached for subsequent runs.
716
+ Decimation keeps the most important Gaussians based on scale and opacity.
717
+ """)
718
+ }
719
+ }
720
+
721
+ // MARK: - Main Entry Point
722
+
723
+ func main() {
724
+ guard let args = CommandLineArgs.parse() else {
725
+ exit(1)
726
+ }
727
+
728
+ do {
729
+ print("Loading SHARP model from \(args.modelPath.path)...")
730
+ let runner = try SHARPModelRunner(modelPath: args.modelPath)
731
+
732
+ print("Preprocessing image \(args.imagePath.path)...")
733
+ let imageArray = try runner.preprocessImage(at: args.imagePath)
734
+
735
+ print("Running inference...")
736
+ let startTime = CFAbsoluteTimeGetCurrent()
737
+ let gaussians = try runner.predict(image: imageArray, focalLengthPx: args.focalLength)
738
+ let inferenceTime = CFAbsoluteTimeGetCurrent() - startTime
739
+
740
+ print("✓ Generated \(gaussians.count) Gaussians in \(String(format: "%.2f", inferenceTime))s")
741
+
742
+ print("Saving PLY file...")
743
+ try runner.savePLY(
744
+ gaussians: gaussians,
745
+ focalLengthPx: args.focalLength,
746
+ imageShape: (height: 1536, width: 1536),
747
+ to: args.outputPath,
748
+ decimation: args.decimation
749
+ )
750
+
751
+ print("✓ Complete!")
752
+
753
+ } catch {
754
+ print("Error: \(error.localizedDescription)")
755
+ if let nsError = error as NSError? {
756
+ print("Domain: \(nsError.domain), Code: \(nsError.code)")
757
+ if let underlyingError = nsError.userInfo[NSUnderlyingErrorKey] as? Error {
758
+ print("Underlying error: \(underlyingError)")
759
+ }
760
+ }
761
+ exit(1)
762
+ }
763
+ }
764
+
765
+ main()
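The PLY layout written by savePLY is a plain ASCII header followed by a little-endian binary payload, so the header is easy to sanity-check from Python. A small sketch (output.ply stands in for whatever path the runner was given):

```python
def read_ply_header(path: str) -> list[str]:
    """Return the ASCII header lines of a binary PLY file, up to end_header."""
    lines = []
    with open(path, "rb") as f:
        while True:
            line = f.readline().decode("ascii").strip()
            lines.append(line)
            if line == "end_header":
                return lines

# Expect vertex, extrinsic, intrinsic, image_size, frame, disparity,
# color_space, and version elements, matching the header written above.
for line in read_ply_header("output.ply"):
    print(line)
```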
test.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b08f5a8cc6f1afffae48c257f0bf51b5f66dc0a13ff02aca16fc8ffe0a9d7f4f
3
+ size 33030941
test.png ADDED

Git LFS Details

  • SHA256: eb80679727edd10314845ac4490e886c0f123aebf13680f0a03cd12978997928
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
viewer.gif ADDED

Git LFS Details

  • SHA256: dc08d861335fcf8f3df546b29a941e67890fd647cfa0cd5c2d2a28691ea7a50f
  • Pointer size: 132 Bytes
  • Size of remote file: 8.76 MB