Kyle Pearson committed
Commit 5fb2d50 · 1 Parent(s): 430c74c

convert + testing scripts

Files changed (2):
  1. convert.py +780 -0
  2. sharp.swift +763 -0
convert.py ADDED
@@ -0,0 +1,780 @@
+ """Convert SHARP PyTorch model to Core ML .mlmodel format.
+
+ This script converts the SHARP (Sharp Monocular View Synthesis) model
+ from PyTorch (.pt) to Core ML (.mlmodel) format for deployment on Apple devices.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ from pathlib import Path
+ from typing import Any
+
+ import coremltools as ct
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ # Import SHARP model components
+ from sharp.models import PredictorParams, create_predictor
+ from sharp.models.predictor import RGBGaussianPredictor
+
+ LOGGER = logging.getLogger(__name__)
+
+ DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
+
+
+ class SafeClamp(nn.Module):
+     """Safe clamp operation that avoids tracing issues."""
+
+     def forward(self, x, min_val=1e-4, max_val=1e4):
+         return torch.clamp(x, min=min_val, max=max_val)
+
+
+ class SafeDivision(nn.Module):
+     """Safe division that avoids division by zero."""
+
+     def forward(self, numerator, denominator):
+         return numerator / torch.clamp(denominator, min=1e-8)
+
+
+ class SharpModelTraceable(nn.Module):
+     """Fully traceable version of SHARP for Core ML conversion.
+
+     This version removes all dynamic control flow and makes the model
+     fully traceable with torch.jit.trace.
+     """
+
+     def __init__(self, predictor: RGBGaussianPredictor):
+         """Initialize the traceable wrapper.
+
+         Args:
+             predictor: The SHARP RGBGaussianPredictor model.
+         """
+         super().__init__()
+         # Copy all submodules
+         self.init_model = predictor.init_model
+         self.feature_model = predictor.feature_model
+         self.monodepth_model = predictor.monodepth_model
+         self.prediction_head = predictor.prediction_head
+         self.gaussian_composer = predictor.gaussian_composer
+         self.depth_alignment = predictor.depth_alignment
+
+         # Replace problematic operations with custom modules
+         self.safe_clamp = SafeClamp()
+         self.safe_div = SafeDivision()
+
+     def forward(
+         self,
+         image: torch.Tensor,
+         disparity_factor: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Run inference with traceable forward pass.
+
+         Args:
+             image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
+             disparity_factor: Disparity factor tensor of shape (1,).
+
+         Returns:
+             Tuple of 5 tensors representing 3D Gaussians.
+         """
+         # Estimate depth using monodepth
+         monodepth_output = self.monodepth_model(image)
+         monodepth_disparity = monodepth_output.disparity
+
+         # Convert disparity to depth with higher precision
+         # Use tighter clamp bounds and higher precision intermediate computation
+         disparity_factor_expanded = disparity_factor[:, None, None, None]
+
+         # Cast to float64 for more precise division, then back to float32
+         disparity_clamped = monodepth_disparity.clamp(min=1e-6, max=1e4)
+         monodepth = disparity_factor_expanded.double() / disparity_clamped.double()
+         monodepth = monodepth.float()
+
+         # Apply depth alignment (inference mode)
+         monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)
+
+         # Initialize gaussians
+         init_output = self.init_model(image, monodepth)
+
+         # Extract features
+         image_features = self.feature_model(
+             init_output.feature_input,
+             encodings=monodepth_output.output_features,
+         )
+
+         # Predict deltas
+         delta_values = self.prediction_head(image_features)
+
+         # Compose final gaussians
+         gaussians = self.gaussian_composer(
+             delta=delta_values,
+             base_values=init_output.gaussian_base_values,
+             global_scale=init_output.global_scale,
+         )
+
+         # Normalize quaternions for consistent validation and inference
+         # This is critical for Core ML conversion accuracy
+         quaternions = gaussians.quaternions
+
+         # Use double precision for quaternion normalization to reduce numerical errors
+         quaternions_fp64 = quaternions.double()
+         quat_norm_sq = torch.sum(quaternions_fp64 * quaternions_fp64, dim=-1, keepdim=True)
+         quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-16))
+         quaternions_normalized = quaternions_fp64 / quat_norm
+
+         # Apply sign canonicalization for consistent representation
+         # Find the component with the largest absolute value
+         abs_quat = torch.abs(quaternions_normalized)
+         max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)
+
+         # Create one-hot selector for the max component
+         one_hot = torch.zeros_like(quaternions_normalized)
+         one_hot.scatter_(-1, max_idx, 1.0)
+
+         # Get the sign of the max component
+         max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)
+
+         # Canonicalize: flip if the max component is negative
+         # This matches the validation logic: np.where(max_component_sign < 0, -q, q)
+         quaternions = torch.where(
+             max_component_sign < 0, -quaternions_normalized, quaternions_normalized
+         ).float()
+
+         return (
+             gaussians.mean_vectors,
+             gaussians.singular_values,
+             quaternions,
+             gaussians.colors,
+             gaussians.opacities,
+         )
+
+
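+ # Worked example for the disparity-to-depth conversion above (illustrative
+ # numbers, not from the model): disparity_factor = focal_length_px / image_width,
+ # so a 1536 px focal length on a 1536 px wide image gives disparity_factor = 1.0,
+ # and a predicted disparity of 0.5 maps to depth = disparity_factor / disparity = 2.0.
+
+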
+ def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
+     """Load SHARP model from checkpoint.
+
+     Args:
+         checkpoint_path: Path to the .pt checkpoint file.
+             If None, downloads the default model.
+
+     Returns:
+         The loaded RGBGaussianPredictor model in eval mode.
+     """
+     if checkpoint_path is None:
+         LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
+         state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
+     else:
+         LOGGER.info("Loading checkpoint from %s", checkpoint_path)
+         state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
+
+     # Create model with default parameters
+     predictor = create_predictor(PredictorParams())
+     predictor.load_state_dict(state_dict)
+     predictor.eval()
+
+     return predictor
+
+
+ def convert_to_coreml(
+     predictor: RGBGaussianPredictor,
+     output_path: Path,
+     input_shape: tuple[int, int] = (1536, 1536),
+     compute_precision: ct.precision = ct.precision.FLOAT16,
+     compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
+     minimum_deployment_target: ct.target | None = None,
+ ) -> ct.models.MLModel:
+     """Convert SHARP model to Core ML format.
+
+     Args:
+         predictor: The SHARP RGBGaussianPredictor model.
+         output_path: Path to save the .mlmodel file.
+         input_shape: Input image shape (height, width). Default is (1536, 1536).
+         compute_precision: Precision for compute (FLOAT16 or FLOAT32).
+         compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
+         minimum_deployment_target: Minimum iOS/macOS deployment target.
+
+     Returns:
+         The converted Core ML model.
+     """
+     LOGGER.info("Preparing model for Core ML conversion...")
+
+     # Ensure depth alignment is disabled for inference
+     predictor.depth_alignment.scale_map_estimator = None
+
+     # Create traceable wrapper
+     model_wrapper = SharpModelTraceable(predictor)
+     model_wrapper.eval()
+
+     # Pre-warm the model with a few forward passes for better tracing
+     LOGGER.info("Pre-warming model for better tracing...")
+     with torch.no_grad():
+         for _ in range(3):
+             warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
+             warm_disparity = torch.tensor([1.0])
+             _ = model_wrapper(warm_image, warm_disparity)
+
+     # Create deterministic example inputs for tracing (same as validation)
+     height, width = input_shape
+     torch.manual_seed(42)  # Use same seed as validation for consistency
+     example_image = torch.randn(1, 3, height, width)
+     example_disparity_factor = torch.tensor([1.0])
+
+     LOGGER.info("Attempting torch.jit.script for better tracing...")
+     try:
+         with torch.no_grad():
+             scripted_model = torch.jit.script(model_wrapper)
+         LOGGER.info("torch.jit.script succeeded, using scripted model")
+         traced_model = scripted_model
+     except Exception as e:
+         LOGGER.warning("torch.jit.script failed: %s", e)
+         LOGGER.info("Falling back to torch.jit.trace...")
+         with torch.no_grad():
+             traced_model = torch.jit.trace(
+                 model_wrapper,
+                 (example_image, example_disparity_factor),
+                 strict=False,  # Allow some flexibility for complex models
+                 check_trace=False,  # Skip trace checking to allow more flexibility
+             )
+
+     LOGGER.info("Converting traced model to Core ML...")
+
+     # Define input types for Core ML
+     inputs = [
+         ct.TensorType(
+             name="image",
+             shape=(1, 3, height, width),
+             dtype=np.float32,
+         ),
+         ct.TensorType(
+             name="disparity_factor",
+             shape=(1,),
+             dtype=np.float32,
+         ),
+     ]
+
+     # Define output names with clear, descriptive labels
+     output_names = [
+         "mean_vectors_3d_positions",  # 3D positions (NDC space)
+         "singular_values_scales",     # Scale parameters (diagonal of covariance)
+         "quaternions_rotations",      # Rotation as quaternions
+         "colors_rgb_linear",          # RGB colors in linear color space
+         "opacities_alpha_channel",    # Opacity values (alpha)
+     ]
+
+     # Define outputs with proper names for Core ML conversion
+     outputs = [ct.TensorType(name=name, dtype=np.float32) for name in output_names]
+
+     # Set up conversion config
+     conversion_kwargs: dict[str, Any] = {
+         "inputs": inputs,
+         "outputs": outputs,  # Specify output names during conversion
+         "convert_to": "mlprogram",  # Use ML Program format for better performance
+         "compute_precision": compute_precision,
+         "compute_units": compute_units,
+     }
+
+     if minimum_deployment_target is not None:
+         conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target
+
+     # Convert to Core ML
+     mlmodel = ct.convert(
+         traced_model,
+         **conversion_kwargs,
+     )
+
+     # Add metadata
+     mlmodel.author = "Apple Inc."
+     mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
+     mlmodel.short_description = (
+         "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
+     )
+     mlmodel.version = "1.0.0"
+
+     # Update input/output descriptions via the spec BEFORE saving
+     spec = mlmodel.get_spec()
+
+     # Input descriptions
+     input_descriptions = {
+         "image": "RGB image normalized to [0, 1], shape (1, 3, H, W)",
+         "disparity_factor": "Focal length / image width ratio, shape (1,)",
+     }
+     for model_input in spec.description.input:
+         if model_input.name in input_descriptions:
+             model_input.shortDescription = input_descriptions[model_input.name]
+
+     # Output descriptions with clear intent and units
+     output_descriptions = {
+         "mean_vectors_3d_positions": (
+             "3D positions of Gaussian splats in normalized device coordinates (NDC). "
+             "Shape: (1, N, 3), where N is the number of Gaussians."
+         ),
+         "singular_values_scales": (
+             "Scale factors for each Gaussian along its principal axes. "
+             "Represents size and anisotropy. Shape: (1, N, 3)."
+         ),
+         "quaternions_rotations": (
+             "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
+             "Used to orient the ellipsoid. Shape: (1, N, 4)."
+         ),
+         "colors_rgb_linear": (
+             "RGB color values in linear RGB space (not gamma-corrected). "
+             "Shape: (1, N, 3), with range [0, 1]."
+         ),
+         "opacities_alpha_channel": (
+             "Opacity value per Gaussian (alpha channel), used for blending. "
+             "Shape: (1, N), where values are in [0, 1]."
+         ),
+     }
+
+     # Update output names and descriptions
+     for i, name in enumerate(output_names):
+         if i < len(spec.description.output):
+             output = spec.description.output[i]
+             output.name = name  # Update name
+             output.shortDescription = output_descriptions[name]  # Add description
+
+     # Validate output names are set correctly
+     LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])
+
+     # get_spec() returns a copy, so rebuild the model from the edited spec;
+     # otherwise the name/description changes would be silently dropped.
+     mlmodel = ct.models.MLModel(spec, weights_dir=mlmodel.weights_dir)
+
+     # Save the model with correct names
+     LOGGER.info("Saving Core ML model to %s", output_path)
+     mlmodel.save(str(output_path))
+
+     return mlmodel
+
+
+ def convert_to_coreml_with_preprocessing(
+     predictor: RGBGaussianPredictor,
+     output_path: Path,
+     input_shape: tuple[int, int] = (1536, 1536),
+ ) -> ct.models.MLModel:
+     """Convert SHARP model to Core ML with built-in image preprocessing.
+
+     This version includes image normalization as part of the model,
+     accepting uint8 images as input.
+
+     Args:
+         predictor: The SHARP RGBGaussianPredictor model.
+         output_path: Path to save the .mlmodel file.
+         input_shape: Input image shape (height, width).
+
+     Returns:
+         The converted Core ML model.
+     """
+
+     class SharpWithPreprocessing(nn.Module):
+         """SHARP model with integrated preprocessing."""
+
+         def __init__(self, base_model: SharpModelTraceable):
+             super().__init__()
+             self.base_model = base_model
+
+         def forward(
+             self,
+             image: torch.Tensor,
+             disparity_factor: torch.Tensor,
+         ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+             # Normalize image from [0, 255] to [0, 1]
+             image_normalized = image / 255.0
+             return self.base_model(image_normalized, disparity_factor)
+
+     model_wrapper = SharpWithPreprocessing(SharpModelTraceable(predictor))
+     model_wrapper.eval()
+
+     height, width = input_shape
+     example_image = torch.randint(0, 256, (1, 3, height, width), dtype=torch.float32)
+     example_disparity_factor = torch.tensor([1.0])
+
+     LOGGER.info("Tracing model with preprocessing...")
+     with torch.no_grad():
+         traced_model = torch.jit.trace(
+             model_wrapper,
+             (example_image, example_disparity_factor),
+             strict=False,
+         )
+
+     inputs = [
+         ct.ImageType(
+             name="image",
+             shape=(1, 3, height, width),
+             scale=1.0,  # Pixels stay in [0, 255]; normalization happens in the model
+             color_layout=ct.colorlayout.RGB,
+         ),
+         ct.TensorType(
+             name="disparity_factor",
+             shape=(1,),
+             dtype=np.float32,
+         ),
+     ]
+
+     # Define output names with clear, descriptive labels
+     output_names = [
+         "mean_vectors_3d_positions",  # 3D positions (NDC space)
+         "singular_values_scales",     # Scale parameters (diagonal of covariance)
+         "quaternions_rotations",      # Rotation as quaternions
+         "colors_rgb_linear",          # RGB colors in linear color space
+         "opacities_alpha_channel",    # Opacity values (alpha)
+     ]
+
+     # Define outputs with proper names for Core ML conversion
+     outputs = [ct.TensorType(name=name, dtype=np.float32) for name in output_names]
+
+     mlmodel = ct.convert(
+         traced_model,
+         inputs=inputs,
+         outputs=outputs,  # Specify output names during conversion
+         convert_to="mlprogram",
+         compute_precision=ct.precision.FLOAT16,
+     )
+
+     mlmodel.author = "Apple Inc."
+     mlmodel.short_description = "SHARP model with integrated image preprocessing"
+     mlmodel.version = "1.0.0"
+
+     # Output descriptions with clear intent and units
+     output_descriptions = {
+         "mean_vectors_3d_positions": (
+             "3D positions of Gaussian splats in normalized device coordinates (NDC). "
+             "Shape: (1, N, 3), where N is the number of Gaussians."
+         ),
+         "singular_values_scales": (
+             "Scale factors for each Gaussian along its principal axes. "
+             "Represents size and anisotropy. Shape: (1, N, 3)."
+         ),
+         "quaternions_rotations": (
+             "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
+             "Used to orient the ellipsoid. Shape: (1, N, 4)."
+         ),
+         "colors_rgb_linear": (
+             "RGB color values in linear RGB space (not gamma-corrected). "
+             "Shape: (1, N, 3), with range [0, 1]."
+         ),
+         "opacities_alpha_channel": (
+             "Opacity value per Gaussian (alpha channel), used for blending. "
+             "Shape: (1, N), where values are in [0, 1]."
+         ),
+     }
+
+     # Update output names and descriptions via the spec BEFORE saving
+     spec = mlmodel.get_spec()
+
+     # Set output descriptions
+     for i, name in enumerate(output_names):
+         if i < len(spec.description.output):
+             output = spec.description.output[i]
+             output.name = name
+             output.shortDescription = output_descriptions[name]
+
+     LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])
+
+     # get_spec() returns a copy, so rebuild the model from the edited spec
+     # before saving.
+     mlmodel = ct.models.MLModel(spec, weights_dir=mlmodel.weights_dir)
+
+     # Save the model with correct names
+     mlmodel.save(str(output_path))
+
+     return mlmodel
+
+
+ def validate_coreml_model(
+     mlmodel: ct.models.MLModel,
+     pytorch_model: RGBGaussianPredictor,
+     input_shape: tuple[int, int] = (1536, 1536),
+     tolerance: float = 0.01,
+ ) -> bool:
+     """Validate Core ML model outputs against PyTorch model.
+
+     Args:
+         mlmodel: The Core ML model to validate.
+         pytorch_model: The original PyTorch model.
+         input_shape: Input image shape (height, width).
+         tolerance: Maximum allowed difference between outputs.
+
+     Returns:
+         True if validation passes, False otherwise.
+     """
+     LOGGER.info("Validating Core ML model against PyTorch...")
+
+     height, width = input_shape
+
+     # Set seeds for reproducibility
+     np.random.seed(42)
+     torch.manual_seed(42)
+
+     # Create test input
+     test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
+     test_disparity = np.array([1.0], dtype=np.float32)
+
+     # Run PyTorch model
+     test_image_pt = torch.from_numpy(test_image_np)
+     test_disparity_pt = torch.from_numpy(test_disparity)
+
+     traceable_wrapper = SharpModelTraceable(pytorch_model)
+     traceable_wrapper.eval()
+
+     with torch.no_grad():
+         pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)
+
+     # Run Core ML model
+     coreml_inputs = {
+         "image": test_image_np,
+         "disparity_factor": test_disparity,
+     }
+     coreml_outputs = mlmodel.predict(coreml_inputs)
+
+     # Debug: print shapes and keys
+     LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
+     LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")
+
+     # Compare outputs with per-output tolerances
+     output_names = [
+         "mean_vectors_3d_positions",
+         "singular_values_scales",
+         "quaternions_rotations",
+         "colors_rgb_linear",
+         "opacities_alpha_channel",
+     ]
+
+     # Per-output absolute tolerances (quaternions are checked angularly below)
+     tolerances = {
+         "mean_vectors_3d_positions": 0.001,
+         "singular_values_scales": 0.0001,
+         "quaternions_rotations": 2.0,
+         "colors_rgb_linear": 0.002,
+         "opacities_alpha_channel": 0.005,
+     }
+
+     # Angular tolerances for quaternions (in degrees)
+     angular_tolerances = {
+         "mean": 0.01,
+         "p99": 0.5,
+         "max": 10.0,
+     }
+
+     all_passed = True
+
+     # Additional diagnostics for depth/position analysis
+     LOGGER.info("=== Depth/Position Statistics ===")
+     pt_positions = pt_outputs[0].numpy()
+     coreml_key = next(k for k in coreml_outputs if "mean_vectors" in k)
+     coreml_positions = coreml_outputs[coreml_key]
+
+     LOGGER.info(
+         f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, "
+         f"{pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, "
+         f"std: {pt_positions[..., 2].std():.4f}"
+     )
+     LOGGER.info(
+         f"Core ML positions - Z range: [{coreml_positions[..., 2].min():.4f}, "
+         f"{coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, "
+         f"std: {coreml_positions[..., 2].std():.4f}"
+     )
+
+     z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
+     LOGGER.info(
+         f"Z-coordinate difference - max: {z_diff.max():.6f}, "
+         f"mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}"
+     )
+     LOGGER.info("=================================")
+
+     # Collect validation results for table output
+     validation_results = []
+
+     for i, name in enumerate(output_names):
+         pt_output = pt_outputs[i].numpy()
+
+         # Find matching Core ML output
+         coreml_key = None
+         if name in coreml_outputs:
+             coreml_key = name
+         else:
+             # Try partial match on the leading word of the expected name
+             base_name = name.split("_")[0]
+             for key in coreml_outputs:
+                 if base_name in key.lower():
+                     coreml_key = key
+                     break
+             if coreml_key is None:
+                 coreml_key = list(coreml_outputs.keys())[i]
+
+         coreml_output = coreml_outputs[coreml_key]
+         result = {"output": name, "passed": True, "failure_reason": ""}
+
+         # Special handling for quaternions
+         if name == "quaternions_rotations":
+             pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
+             pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)
+
+             coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
+             coreml_output_normalized = coreml_output / np.clip(coreml_quat_norm, 1e-12, None)
+
+             def canonicalize_quaternion(q):
+                 abs_q = np.abs(q)
+                 max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
+                 selector = np.zeros_like(q)
+                 np.put_along_axis(selector, max_component_idx, 1, axis=-1)
+                 max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
+                 return np.where(max_component_sign < 0, -q, q)
+
+             pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
+             coreml_output_canonical = canonicalize_quaternion(coreml_output_normalized)
+
+             diff = np.abs(pt_output_canonical - coreml_output_canonical)
+             dot_products = np.sum(pt_output_canonical * coreml_output_canonical, axis=-1)
+             dot_products = np.clip(np.abs(dot_products), 0.0, 1.0)
+             angular_diff_rad = 2 * np.arccos(dot_products)
+             angular_diff_deg = np.degrees(angular_diff_rad)
+             max_angular = np.max(angular_diff_deg)
+             mean_angular = np.mean(angular_diff_deg)
+             p99_angular = np.percentile(angular_diff_deg, 99)
+
+             quat_passed = True
+             failure_reasons = []
+
+             if mean_angular > angular_tolerances["mean"]:
+                 quat_passed = False
+                 failure_reasons.append(
+                     f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°"
+                 )
+             if p99_angular > angular_tolerances["p99"]:
+                 quat_passed = False
+                 failure_reasons.append(
+                     f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°"
+                 )
+             if max_angular > angular_tolerances["max"]:
+                 quat_passed = False
+                 failure_reasons.append(
+                     f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°"
+                 )
+
+             result.update({
+                 "max_diff": f"{np.max(diff):.6f}",
+                 "mean_diff": f"{np.mean(diff):.6f}",
+                 "p99_diff": f"{np.percentile(diff, 99):.6f}",
+                 "max_angular": f"{max_angular:.4f}",
+                 "mean_angular": f"{mean_angular:.4f}",
+                 "p99_angular": f"{p99_angular:.4f}",
+                 "passed": quat_passed,
+                 "failure_reason": "; ".join(failure_reasons) if failure_reasons else "",
+             })
+             if not quat_passed:
+                 all_passed = False
+         else:
+             diff = np.abs(pt_output - coreml_output)
+             output_tolerance = tolerances.get(name, tolerance)
+             result.update({
+                 "max_diff": f"{np.max(diff):.6f}",
+                 "mean_diff": f"{np.mean(diff):.6f}",
+                 "p99_diff": f"{np.percentile(diff, 99):.6f}",
+                 "tolerance": f"{output_tolerance:.6f}",
+             })
+             if np.max(diff) > output_tolerance:
+                 result["passed"] = False
+                 result["failure_reason"] = (
+                     f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
+                 )
+                 all_passed = False
+
+         validation_results.append(result)
+
+     # Output validation results as a markdown table
+     if validation_results:
+         LOGGER.info("\n### Validation Results\n")
+         LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
+         LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")
+
+         for result in validation_results:
+             output_name = result["output"].replace("_", " ").title()
+             if "max_angular" in result:
+                 angular_info = (
+                     f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
+                 )
+             else:
+                 angular_info = "-"
+             status = "✅ PASS" if result["passed"] else "❌ FAIL"
+             LOGGER.info(
+                 f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | "
+                 f"{result['p99_diff']} | {angular_info} | {status} |"
+             )
+         LOGGER.info("")
+
+     return all_passed
+
+
+ def main() -> int:
+     """Main conversion script."""
+     parser = argparse.ArgumentParser(
+         description="Convert SHARP PyTorch model to Core ML format"
+     )
+     parser.add_argument(
+         "-c", "--checkpoint",
+         type=Path,
+         default=None,
+         help="Path to PyTorch checkpoint. Downloads default if not provided.",
+     )
+     parser.add_argument(
+         "-o", "--output",
+         type=Path,
+         default=Path("sharp.mlpackage"),
+         help="Output path for Core ML model (default: sharp.mlpackage)",
+     )
+     parser.add_argument(
+         "--height",
+         type=int,
+         default=1536,
+         help="Input image height (default: 1536)",
+     )
+     parser.add_argument(
+         "--width",
+         type=int,
+         default=1536,
+         help="Input image width (default: 1536)",
+     )
+     parser.add_argument(
+         "--precision",
+         choices=["float16", "float32"],
+         default="float32",
+         help="Compute precision (default: float32)",
+     )
+     parser.add_argument(
+         "--validate",
+         action="store_true",
+         help="Validate Core ML model against PyTorch",
+     )
+     parser.add_argument(
+         "--with-preprocessing",
+         action="store_true",
+         help="Include image preprocessing (uint8 -> float normalization)",
+     )
+     parser.add_argument(
+         "-v", "--verbose",
+         action="store_true",
+         help="Enable verbose logging",
+     )
+
+     args = parser.parse_args()
+
+     # Configure logging
+     logging.basicConfig(
+         level=logging.DEBUG if args.verbose else logging.INFO,
+         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+     )
+
+     # Load PyTorch model
+     LOGGER.info("Loading SHARP model...")
+     predictor = load_sharp_model(args.checkpoint)
+
+     # Set up conversion parameters
+     input_shape = (args.height, args.width)
+     precision = ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32
+
+     # Convert to Core ML
+     if args.with_preprocessing:
+         LOGGER.info("Converting with integrated preprocessing...")
+         mlmodel = convert_to_coreml_with_preprocessing(
+             predictor,
+             args.output,
+             input_shape=input_shape,
+         )
+     else:
+         LOGGER.info("Converting using direct tracing...")
+         mlmodel = convert_to_coreml(
+             predictor,
+             args.output,
+             input_shape=input_shape,
+             compute_precision=precision,
+         )
+
+     LOGGER.info(f"Core ML model saved to {args.output}")
+
+     # Validate if requested
+     if args.validate:
+         validation_passed = validate_coreml_model(mlmodel, predictor, input_shape)
+
+         if validation_passed:
+             LOGGER.info("✓ Validation passed!")
+         else:
+             LOGGER.error("✗ Validation failed!")
+             return 1
+
+     LOGGER.info("Conversion complete!")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
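
For reference, a minimal sketch of driving the conversion from Python using the functions above (assuming the script is importable as convert; the 768x768 shape and file names are illustrative):

    from pathlib import Path

    import coremltools as ct

    from convert import convert_to_coreml, load_sharp_model, validate_coreml_model

    # Downloads DEFAULT_MODEL_URL when no checkpoint path is given
    predictor = load_sharp_model(None)

    # Convert at an illustrative resolution; float32 keeps the validation comparison meaningful
    mlmodel = convert_to_coreml(
        predictor,
        Path("sharp_768.mlpackage"),
        input_shape=(768, 768),
        compute_precision=ct.precision.FLOAT32,
    )

    # Compare the Core ML outputs against the traceable PyTorch wrapper
    assert validate_coreml_model(mlmodel, predictor, input_shape=(768, 768))

The same flow is available from the command line, e.g. python convert.py -o sharp.mlpackage --precision float32 --validate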
sharp.swift ADDED
@@ -0,0 +1,763 @@
+ //
+ //  SHARPModelRunner.swift
+ //  SHARP Model Inference and PLY Export
+ //
+ //  Loads a SHARP Core ML model, runs inference on an image,
+ //  and saves the 3D Gaussian splat output as a PLY file.
+ //
+ //  Usage:
+ //    swiftc -O -o sharp_runner sharp.swift -framework CoreML -framework CoreImage -framework AppKit
+ //    ./sharp_runner sharp.mlpackage test.png output.ply -d 0.5
+
+ import Foundation
+ import CoreML
+ import CoreImage
+ import AppKit  // For NSImage on macOS; use UIKit for iOS
+
+ // MARK: - Gaussians3D Structure
+
+ /// Represents the output of the SHARP model - a collection of 3D Gaussians
+ struct Gaussians3D {
+     let meanVectors: MLMultiArray     // Shape: (1, N, 3) - 3D positions
+     let singularValues: MLMultiArray  // Shape: (1, N, 3) - scales
+     let quaternions: MLMultiArray     // Shape: (1, N, 4) - rotations
+     let colors: MLMultiArray          // Shape: (1, N, 3) - RGB colors (linear)
+     let opacities: MLMultiArray       // Shape: (1, N) - opacity values
+
+     var count: Int {
+         return meanVectors.shape[1].intValue
+     }
+
+     /// Compute importance scores for each Gaussian.
+     /// Higher scores = more important (larger and more opaque).
+     func computeImportanceScores() -> [Float] {
+         let n = count
+         var scores = [Float](repeating: 0, count: n)
+
+         let scalePtr = singularValues.dataPointer.assumingMemoryBound(to: Float.self)
+         let opacityPtr = opacities.dataPointer.assumingMemoryBound(to: Float.self)
+
+         for i in 0..<n {
+             // singularValues are linear-space scales (already exp(log_scale)),
+             // so their product is equivalent to exp(log_s0 + log_s1 + log_s2)
+             // and is proportional to the Gaussian's volume.
+             let s0 = scalePtr[i * 3 + 0]
+             let s1 = scalePtr[i * 3 + 1]
+             let s2 = scalePtr[i * 3 + 2]
+             let scaleProduct = s0 * s1 * s2
+
+             // Opacity is already in [0, 1] range (after sigmoid in model)
+             let opacity = opacityPtr[i]
+
+             scores[i] = scaleProduct * opacity
+         }
+
+         return scores
+     }
+
+     /// Decimate the Gaussians by keeping only a fraction based on importance.
+     /// Returns indices of Gaussians to keep, sorted for spatial coherence.
+     func decimationIndices(keepRatio: Float) -> [Int] {
+         let n = count
+         let keepCount = max(1, Int(Float(n) * keepRatio))
+
+         // Compute importance scores
+         let scores = computeImportanceScores()
+
+         // Create array of (index, score) pairs and sort by score descending
+         var indexedScores = scores.enumerated().map { ($0.offset, $0.element) }
+         indexedScores.sort { $0.1 > $1.1 }
+
+         // Get top keepCount indices
+         var keepIndices = indexedScores.prefix(keepCount).map { $0.0 }
+
+         // Sort indices to maintain spatial coherence
+         keepIndices.sort()
+
+         return keepIndices
+     }
+ }
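+
+ // Worked decimation example (illustrative numbers, not model output): a Gaussian
+ // with scales (0.02, 0.02, 0.01) and opacity 0.8 gets the score
+ // 0.02 * 0.02 * 0.01 * 0.8 = 3.2e-6; with keepRatio = 0.5 the top-scoring half
+ // of all Gaussians survives, and the kept indices are re-sorted ascending.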
+
+ // MARK: - Color Space Utilities
+
+ /// Convert linear RGB to sRGB color space
+ func linearRGBToSRGB(_ linear: Float) -> Float {
+     if linear <= 0.0031308 {
+         return linear * 12.92
+     } else {
+         return 1.055 * pow(linear, 1.0 / 2.4) - 0.055
+     }
+ }
+
+ /// Convert RGB to degree-0 spherical harmonics
+ func rgbToSphericalHarmonics(_ rgb: Float) -> Float {
+     let coeffDegree0 = sqrt(1.0 / (4.0 * Float.pi))
+     return (rgb - 0.5) / coeffDegree0
+ }
+
+ /// Inverse sigmoid function
+ func inverseSigmoid(_ x: Float) -> Float {
+     let clamped = min(max(x, 1e-6), 1.0 - 1e-6)
+     return log(clamped / (1.0 - clamped))
+ }
+
+ // MARK: - SHARP Model Wrapper
+
+ class SHARPModelRunner {
+     private let model: MLModel
+     private let inputHeight: Int
+     private let inputWidth: Int
+
+     init(modelPath: URL, inputHeight: Int = 1536, inputWidth: Int = 1536) throws {
+         let config = MLModelConfiguration()
+         config.computeUnits = .all
+
+         // Compile the model if needed
+         let compiledModelURL = try SHARPModelRunner.compileModelIfNeeded(at: modelPath)
+
+         self.model = try MLModel(contentsOf: compiledModelURL, configuration: config)
+         self.inputHeight = inputHeight
+         self.inputWidth = inputWidth
+
+         // Print model description for debugging
+         print("Model inputs: \(model.modelDescription.inputDescriptionsByName.keys.joined(separator: ", "))")
+         print("Model outputs: \(model.modelDescription.outputDescriptionsByName.keys.joined(separator: ", "))")
+     }
+
+     /// Compile the model if it's not already compiled
+     private static func compileModelIfNeeded(at modelPath: URL) throws -> URL {
+         let fileManager = FileManager.default
+         let pathExtension = modelPath.pathExtension.lowercased()
+
+         // If already compiled (.mlmodelc), return as-is
+         if pathExtension == "mlmodelc" {
+             print("Model is already compiled.")
+             return modelPath
+         }
+
+         // Check if it's an .mlpackage or .mlmodel that needs compilation
+         guard pathExtension == "mlpackage" || pathExtension == "mlmodel" else {
+             throw NSError(domain: "SHARPModelRunner", code: 10,
+                           userInfo: [NSLocalizedDescriptionKey: "Unsupported model format: \(pathExtension). Use .mlpackage, .mlmodel, or .mlmodelc"])
+         }
+
+         // Create a cache directory for compiled models
+         let cacheDir = fileManager.temporaryDirectory.appendingPathComponent("SHARPModelCache")
+         try? fileManager.createDirectory(at: cacheDir, withIntermediateDirectories: true)
+
+         // Generate a unique name for the compiled model based on the source path
+         let modelName = modelPath.deletingPathExtension().lastPathComponent
+         let compiledPath = cacheDir.appendingPathComponent("\(modelName).mlmodelc")
+
+         // Check if we have a cached compiled version
+         if fileManager.fileExists(atPath: compiledPath.path) {
+             // Verify the cached version is newer than the source
+             let sourceAttrs = try fileManager.attributesOfItem(atPath: modelPath.path)
+             let cachedAttrs = try fileManager.attributesOfItem(atPath: compiledPath.path)
+
+             if let sourceDate = sourceAttrs[.modificationDate] as? Date,
+                let cachedDate = cachedAttrs[.modificationDate] as? Date,
+                cachedDate >= sourceDate {
+                 print("Using cached compiled model at \(compiledPath.path)")
+                 return compiledPath
+             } else {
+                 // Source is newer, remove old cached version
+                 try? fileManager.removeItem(at: compiledPath)
+             }
+         }
+
+         // Compile the model
+         print("Compiling model (this may take a moment)...")
+         let startTime = CFAbsoluteTimeGetCurrent()
+
+         let temporaryCompiledURL = try MLModel.compileModel(at: modelPath)
+
+         let compileTime = CFAbsoluteTimeGetCurrent() - startTime
+         print("✓ Model compiled in \(String(format: "%.1f", compileTime))s")
+
+         // Move to our cache directory
+         try? fileManager.removeItem(at: compiledPath)
+         try fileManager.moveItem(at: temporaryCompiledURL, to: compiledPath)
+
+         print("Compiled model cached at \(compiledPath.path)")
+         return compiledPath
+     }
+
+     /// Load and preprocess an image for model input
+     func preprocessImage(at imagePath: URL) throws -> MLMultiArray {
+         guard let nsImage = NSImage(contentsOf: imagePath) else {
+             throw NSError(domain: "SHARPModelRunner", code: 1,
+                           userInfo: [NSLocalizedDescriptionKey: "Failed to load image from \(imagePath.path)"])
+         }
+
+         guard let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
+             throw NSError(domain: "SHARPModelRunner", code: 2,
+                           userInfo: [NSLocalizedDescriptionKey: "Failed to convert to CGImage"])
+         }
+
+         // Create CIImage and resize
+         let ciImage = CIImage(cgImage: cgImage)
+         let context = CIContext()
+
+         // Scale to target size
+         let scaleX = CGFloat(inputWidth) / ciImage.extent.width
+         let scaleY = CGFloat(inputHeight) / ciImage.extent.height
+         let scaledImage = ciImage.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY))
+
+         // Render to bitmap
+         guard let resizedCGImage = context.createCGImage(scaledImage, from: CGRect(x: 0, y: 0,
+                                                                                    width: inputWidth,
+                                                                                    height: inputHeight)) else {
+             throw NSError(domain: "SHARPModelRunner", code: 3,
+                           userInfo: [NSLocalizedDescriptionKey: "Failed to resize image"])
+         }
+
+         // Convert to MLMultiArray (1, 3, H, W) normalized to [0, 1]
+         let imageArray = try MLMultiArray(shape: [1, 3, NSNumber(value: inputHeight), NSNumber(value: inputWidth)],
+                                           dataType: .float32)
+
+         let width = resizedCGImage.width
+         let height = resizedCGImage.height
+         let bytesPerPixel = 4
+         let bytesPerRow = bytesPerPixel * width
+         var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow)
+
+         let colorSpace = CGColorSpaceCreateDeviceRGB()
+         guard let cgContext = CGContext(data: &pixelData,
+                                         width: width,
+                                         height: height,
+                                         bitsPerComponent: 8,
+                                         bytesPerRow: bytesPerRow,
+                                         space: colorSpace,
+                                         bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue) else {
+             throw NSError(domain: "SHARPModelRunner", code: 4,
+                           userInfo: [NSLocalizedDescriptionKey: "Failed to create bitmap context"])
+         }
+
+         cgContext.draw(resizedCGImage, in: CGRect(x: 0, y: 0, width: width, height: height))
+
+         // Copy pixel data to MLMultiArray in CHW format
+         // Use pointer access for better performance
+         let ptr = imageArray.dataPointer.assumingMemoryBound(to: Float.self)
+         let channelStride = inputHeight * inputWidth
+
+         for y in 0..<height {
+             for x in 0..<width {
+                 let pixelIndex = y * bytesPerRow + x * bytesPerPixel
+                 let r = Float(pixelData[pixelIndex]) / 255.0
+                 let g = Float(pixelData[pixelIndex + 1]) / 255.0
+                 let b = Float(pixelData[pixelIndex + 2]) / 255.0
+
+                 let spatialIndex = y * inputWidth + x
+                 ptr[0 * channelStride + spatialIndex] = r
+                 ptr[1 * channelStride + spatialIndex] = g
+                 ptr[2 * channelStride + spatialIndex] = b
+             }
+         }
+
+         return imageArray
+     }
+
+     /// Run inference on the model
+     func predict(image: MLMultiArray, focalLengthPx: Float) throws -> Gaussians3D {
+         // Calculate disparity factor: focal_length / image_width
+         let disparityFactor = focalLengthPx / Float(inputWidth)
+
+         // Create disparity factor input
+         let disparityArray = try MLMultiArray(shape: [1], dataType: .float32)
+         disparityArray[0] = NSNumber(value: disparityFactor)
+
+         // Create feature provider
+         let inputFeatures = try MLDictionaryFeatureProvider(dictionary: [
+             "image": MLFeatureValue(multiArray: image),
+             "disparity_factor": MLFeatureValue(multiArray: disparityArray)
+         ])
+
+         // Run prediction
+         let output = try model.prediction(from: inputFeatures)
+
+         // Try to find outputs by checking available names
+         let outputNames = Array(model.modelDescription.outputDescriptionsByName.keys)
+
+         // Helper function to find output by partial name match
+         func findOutput(containing keywords: [String]) -> MLMultiArray? {
+             for name in outputNames {
+                 let lowercaseName = name.lowercased()
+                 for keyword in keywords {
+                     if lowercaseName.contains(keyword.lowercased()) {
+                         return output.featureValue(for: name)?.multiArrayValue
+                     }
+                 }
+             }
+             return nil
+         }
+
+         // Try to match outputs - first try exact names, then partial matches
+         let meanVectors = output.featureValue(for: "mean_vectors_3d_positions")?.multiArrayValue
+             ?? findOutput(containing: ["mean", "position", "xyz"])
+
+         let singularValues = output.featureValue(for: "singular_values_scales")?.multiArrayValue
+             ?? findOutput(containing: ["singular", "scale"])
+
+         let quaternions = output.featureValue(for: "quaternions_rotations")?.multiArrayValue
+             ?? findOutput(containing: ["quaternion", "rotation", "rot"])
+
+         let colors = output.featureValue(for: "colors_rgb_linear")?.multiArrayValue
+             ?? findOutput(containing: ["color", "rgb"])
+
+         let opacities = output.featureValue(for: "opacities_alpha_channel")?.multiArrayValue
+             ?? findOutput(containing: ["opacity", "alpha"])
+
+         // If we still couldn't find outputs, try by index order
+         if meanVectors == nil || singularValues == nil || quaternions == nil || colors == nil || opacities == nil {
+             print("Warning: Could not match all outputs by name. Available outputs: \(outputNames)")
+
+             // Fall back to sorted order if we have at least 5 outputs
+             if outputNames.count >= 5 {
+                 let sortedNames = outputNames.sorted()
+                 guard let mv = output.featureValue(for: sortedNames[0])?.multiArrayValue,
+                       let sv = output.featureValue(for: sortedNames[1])?.multiArrayValue,
+                       let q = output.featureValue(for: sortedNames[2])?.multiArrayValue,
+                       let c = output.featureValue(for: sortedNames[3])?.multiArrayValue,
+                       let o = output.featureValue(for: sortedNames[4])?.multiArrayValue else {
+                     throw NSError(domain: "SHARPModelRunner", code: 5,
+                                   userInfo: [NSLocalizedDescriptionKey: "Failed to extract model outputs. Available: \(outputNames)"])
+                 }
+
+                 print("Using outputs by sorted order: \(sortedNames)")
+                 return Gaussians3D(
+                     meanVectors: mv,
+                     singularValues: sv,
+                     quaternions: q,
+                     colors: c,
+                     opacities: o
+                 )
+             }
+
+             throw NSError(domain: "SHARPModelRunner", code: 5,
+                           userInfo: [NSLocalizedDescriptionKey: "Failed to extract model outputs. Available: \(outputNames)"])
+         }
+
+         return Gaussians3D(
+             meanVectors: meanVectors!,
+             singularValues: singularValues!,
+             quaternions: quaternions!,
+             colors: colors!,
+             opacities: opacities!
+         )
+     }
+
+     /// Save Gaussians to PLY file (matching Python save_ply format exactly)
+     /// - Parameters:
+     ///   - gaussians: The Gaussians to save
+     ///   - focalLengthPx: Focal length in pixels
+     ///   - imageShape: Image dimensions (height, width)
+     ///   - outputPath: Output file path
+     ///   - decimation: Optional decimation ratio (0.0-1.0). 1.0 = keep all, 0.5 = keep 50%
+     func savePLY(gaussians: Gaussians3D,
+                  focalLengthPx: Float,
+                  imageShape: (height: Int, width: Int),
+                  to outputPath: URL,
+                  decimation: Float = 1.0) throws {
+
+         let imageHeight = imageShape.height
+         let imageWidth = imageShape.width
+
+         // Determine which indices to keep based on decimation
+         let keepIndices: [Int]
+         let originalCount = gaussians.count
+
+         if decimation < 1.0 {
+             keepIndices = gaussians.decimationIndices(keepRatio: decimation)
+             print("Decimating: keeping \(keepIndices.count) of \(originalCount) Gaussians (\(String(format: "%.1f", decimation * 100))%)")
+         } else {
+             keepIndices = Array(0..<originalCount)
+         }
+
+         let numGaussians = keepIndices.count
+
+         var fileContent = Data()
+
+         // Helper to append string
+         func appendString(_ str: String) {
+             fileContent.append(str.data(using: .ascii)!)
+         }
+
+         // Helper to append float32 in little-endian
+         func appendFloat32(_ value: Float) {
+             var v = value
+             fileContent.append(Data(bytes: &v, count: 4))
+         }
+
+         // Helper to append int32 in little-endian
+         func appendInt32(_ value: Int32) {
+             var v = value
+             fileContent.append(Data(bytes: &v, count: 4))
+         }
+
+         // Helper to append uint32 in little-endian
+         func appendUInt32(_ value: UInt32) {
+             var v = value
+             fileContent.append(Data(bytes: &v, count: 4))
+         }
+
+         // Helper to append uint8
+         func appendUInt8(_ value: UInt8) {
+             var v = value
+             fileContent.append(Data(bytes: &v, count: 1))
+         }
+
+         // ===== PLY Header =====
+         appendString("ply\n")
+         appendString("format binary_little_endian 1.0\n")
+
+         // Vertex element
+         appendString("element vertex \(numGaussians)\n")
+         appendString("property float x\n")
+         appendString("property float y\n")
+         appendString("property float z\n")
+         appendString("property float f_dc_0\n")
+         appendString("property float f_dc_1\n")
+         appendString("property float f_dc_2\n")
+         appendString("property float opacity\n")
+         appendString("property float scale_0\n")
+         appendString("property float scale_1\n")
+         appendString("property float scale_2\n")
+         appendString("property float rot_0\n")
+         appendString("property float rot_1\n")
+         appendString("property float rot_2\n")
+         appendString("property float rot_3\n")
+
+         // Extrinsic element (16 floats for 4x4 identity matrix)
+         appendString("element extrinsic 16\n")
+         appendString("property float extrinsic\n")
+
+         // Intrinsic element (9 floats for 3x3 matrix)
+         appendString("element intrinsic 9\n")
+         appendString("property float intrinsic\n")
+
+         // Image size element
+         appendString("element image_size 2\n")
+         appendString("property uint image_size\n")
+
+         // Frame element
+         appendString("element frame 2\n")
+         appendString("property int frame\n")
+
+         // Disparity element
+         appendString("element disparity 2\n")
+         appendString("property float disparity\n")
+
+         // Color space element
+         appendString("element color_space 1\n")
+         appendString("property uchar color_space\n")
+
+         // Version element
+         appendString("element version 3\n")
+         appendString("property uchar version\n")
+
+         appendString("end_header\n")
+
+         // ===== Vertex Data =====
+         // Compute disparity quantiles for later
+         var disparities: [Float] = []
+
+         // Get pointers for faster access
+         let meanPtr = gaussians.meanVectors.dataPointer.assumingMemoryBound(to: Float.self)
+         let scalePtr = gaussians.singularValues.dataPointer.assumingMemoryBound(to: Float.self)
+         let quatPtr = gaussians.quaternions.dataPointer.assumingMemoryBound(to: Float.self)
+         let colorPtr = gaussians.colors.dataPointer.assumingMemoryBound(to: Float.self)
+         let opacityPtr = gaussians.opacities.dataPointer.assumingMemoryBound(to: Float.self)
+
+         for i in keepIndices {
+             // Position (x, y, z)
+             let x = meanPtr[i * 3 + 0]
+             let y = meanPtr[i * 3 + 1]
+             let z = meanPtr[i * 3 + 2]
+             appendFloat32(x)
+             appendFloat32(y)
+             appendFloat32(z)
+
+             // Compute disparity for quantiles
+             if z > 1e-6 {
+                 disparities.append(1.0 / z)
+             }
+
+             // Colors: Convert linear RGB -> sRGB -> spherical harmonics
+             let colorR = colorPtr[i * 3 + 0]
+             let colorG = colorPtr[i * 3 + 1]
+             let colorB = colorPtr[i * 3 + 2]
+
+             let srgbR = linearRGBToSRGB(colorR)
+             let srgbG = linearRGBToSRGB(colorG)
+             let srgbB = linearRGBToSRGB(colorB)
+
+             let sh0 = rgbToSphericalHarmonics(srgbR)
+             let sh1 = rgbToSphericalHarmonics(srgbG)
+             let sh2 = rgbToSphericalHarmonics(srgbB)
+
+             appendFloat32(sh0)
+             appendFloat32(sh1)
+             appendFloat32(sh2)
+
+             // Opacity: Convert to logits using inverse sigmoid
+             let opacity = opacityPtr[i]
+             let opacityLogit = inverseSigmoid(opacity)
+             appendFloat32(opacityLogit)
+
+             // Scales: Convert to log scale
+             let scale0 = scalePtr[i * 3 + 0]
+             let scale1 = scalePtr[i * 3 + 1]
+             let scale2 = scalePtr[i * 3 + 2]
+
+             appendFloat32(log(max(scale0, 1e-10)))
+             appendFloat32(log(max(scale1, 1e-10)))
+             appendFloat32(log(max(scale2, 1e-10)))
+
+             // Quaternions (w, x, y, z)
+             let q0 = quatPtr[i * 4 + 0]
+             let q1 = quatPtr[i * 4 + 1]
+             let q2 = quatPtr[i * 4 + 2]
+             let q3 = quatPtr[i * 4 + 3]
+
+             appendFloat32(q0)
+             appendFloat32(q1)
+             appendFloat32(q2)
+             appendFloat32(q3)
+         }
+
+         // ===== Extrinsic Data (4x4 identity matrix) =====
+         let identity: [Float] = [
+             1, 0, 0, 0,
+             0, 1, 0, 0,
+             0, 0, 1, 0,
+             0, 0, 0, 1
+         ]
+         for val in identity {
+             appendFloat32(val)
+         }
+
+         // ===== Intrinsic Data (3x3 matrix) =====
+         let intrinsic: [Float] = [
+             focalLengthPx, 0, Float(imageWidth) * 0.5,
+             0, focalLengthPx, Float(imageHeight) * 0.5,
+             0, 0, 1
+         ]
+         for val in intrinsic {
+             appendFloat32(val)
+         }
+
+         // ===== Image Size Data =====
+         appendUInt32(UInt32(imageWidth))
+         appendUInt32(UInt32(imageHeight))
+
+         // ===== Frame Data =====
+         appendInt32(1)  // Number of frames
+         appendInt32(Int32(numGaussians))  // Particles per frame
+
+         // ===== Disparity Data (quantiles) =====
+         disparities.sort()
+         let q10Index = Int(Float(disparities.count) * 0.1)
+         let q90Index = Int(Float(disparities.count) * 0.9)
+         let disparity10 = disparities.isEmpty ? 0.0 : disparities[min(q10Index, disparities.count - 1)]
+         let disparity90 = disparities.isEmpty ? 1.0 : disparities[min(q90Index, disparities.count - 1)]
+         appendFloat32(disparity10)
+         appendFloat32(disparity90)
+
+         // ===== Color Space Data (sRGB = 1) =====
+         appendUInt8(1)
+
+         // ===== Version Data =====
+         appendUInt8(1)  // Major
+         appendUInt8(5)  // Minor
+         appendUInt8(0)  // Patch
+
+         // Write to file
+         try fileContent.write(to: outputPath)
+
+         print("✓ Saved PLY with \(numGaussians) Gaussians to \(outputPath.path)")
+     }
+ }
+
+ // MARK: - Command Line Argument Parsing
+
+ struct CommandLineArgs {
+     let modelPath: URL
+     let imagePath: URL
+     let outputPath: URL
+     let focalLength: Float
+     let decimation: Float
+
+     static func parse() -> CommandLineArgs? {
+         let args = CommandLine.arguments
+
+         var modelPath: URL?
+         var imagePath: URL?
+         var outputPath: URL?
+         var focalLength: Float = 1536.0
+         var decimation: Float = 1.0
+
+         var i = 1
+         while i < args.count {
+             let arg = args[i]
+
+             switch arg {
+             case "-m", "--model":
+                 i += 1
+                 if i < args.count {
+                     modelPath = URL(fileURLWithPath: args[i])
+                 }
+
+             case "-i", "--input":
+                 i += 1
+                 if i < args.count {
+                     imagePath = URL(fileURLWithPath: args[i])
+                 }
+
+             case "-o", "--output":
+                 i += 1
+                 if i < args.count {
+                     outputPath = URL(fileURLWithPath: args[i])
+                 }
+
+             case "-f", "--focal-length":
+                 i += 1
+                 if i < args.count {
+                     focalLength = Float(args[i]) ?? 1536.0
+                 }
+
+             case "-d", "--decimation":
+                 i += 1
+                 if i < args.count {
+                     if let value = Float(args[i]) {
+                         // Accept both percentage (0-100) and ratio (0-1)
+                         if value > 1.0 {
+                             decimation = value / 100.0
+                         } else {
+                             decimation = value
+                         }
+                         decimation = max(0.01, min(1.0, decimation))
+                     }
+                 }
+
+             case "-h", "--help":
+                 printUsage()
+                 return nil
+
+             default:
+                 // Handle positional arguments for backward compatibility
+                 if modelPath == nil {
+                     modelPath = URL(fileURLWithPath: arg)
+                 } else if imagePath == nil {
+                     imagePath = URL(fileURLWithPath: arg)
+                 } else if outputPath == nil {
+                     outputPath = URL(fileURLWithPath: arg)
+                 } else if focalLength == 1536.0 {
+                     focalLength = Float(arg) ?? 1536.0
+                 }
+             }
+
+             i += 1
+         }
+
+         guard let model = modelPath, let image = imagePath, let output = outputPath else {
+             printUsage()
+             return nil
+         }
+
+         return CommandLineArgs(
+             modelPath: model,
+             imagePath: image,
+             outputPath: output,
+             focalLength: focalLength,
+             decimation: decimation
+         )
+     }
+
+     static func printUsage() {
+         let execName = CommandLine.arguments[0].components(separatedBy: "/").last ?? "sharp_runner"
+         print("""
+         Usage: \(execName) [OPTIONS] <model> <input_image> <output.ply>
+
+         SHARP Model Inference - Generate 3D Gaussian Splats from a single image
+
+         Arguments:
+           model                     Path to the SHARP Core ML model (.mlpackage, .mlmodel, or .mlmodelc)
+           input_image               Path to input image (PNG, JPEG, etc.)
+           output.ply                Path for output PLY file
+
+         Options:
+           -m, --model PATH          Path to Core ML model
+           -i, --input PATH          Path to input image
+           -o, --output PATH         Path for output PLY file
+           -f, --focal-length FLOAT  Focal length in pixels (default: 1536)
+           -d, --decimation FLOAT    Decimation ratio 0.0-1.0 or percentage 1-100 (default: 1.0 = keep all)
+                                     Example: 0.5 or 50 keeps 50% of Gaussians
+           -h, --help                Show this help message
+
+         Examples:
+           # Basic usage
+           \(execName) sharp.mlpackage photo.jpg output.ply
+
+           # With focal length
+           \(execName) sharp.mlpackage photo.jpg output.ply 768
+
+           # With decimation (keep 50% of points)
+           \(execName) -m sharp.mlpackage -i photo.jpg -o output.ply -d 0.5
+
+           # With decimation as percentage
+           \(execName) -m sharp.mlpackage -i photo.jpg -o output.ply -d 25
+
+         The model will be automatically compiled on first use and cached for subsequent runs.
+         Decimation keeps the most important Gaussians based on scale and opacity.
+         """)
+     }
+ }
+
+ // MARK: - Main Entry Point
+
+ func main() {
+     guard let args = CommandLineArgs.parse() else {
+         exit(1)
+     }
+
+     do {
+         print("Loading SHARP model from \(args.modelPath.path)...")
+         let runner = try SHARPModelRunner(modelPath: args.modelPath)
+
+         print("Preprocessing image \(args.imagePath.path)...")
+         let imageArray = try runner.preprocessImage(at: args.imagePath)
+
+         print("Running inference...")
+         let startTime = CFAbsoluteTimeGetCurrent()
+         let gaussians = try runner.predict(image: imageArray, focalLengthPx: args.focalLength)
+         let inferenceTime = CFAbsoluteTimeGetCurrent() - startTime
+
+         print("✓ Generated \(gaussians.count) Gaussians in \(String(format: "%.2f", inferenceTime))s")
+
+         print("Saving PLY file...")
+         try runner.savePLY(
+             gaussians: gaussians,
+             focalLengthPx: args.focalLength,
+             imageShape: (height: 1536, width: 1536),
+             to: args.outputPath,
+             decimation: args.decimation
+         )
+
+         print("✓ Complete!")
+
+     } catch {
+         print("Error: \(error.localizedDescription)")
+         let nsError = error as NSError  // Bridging to NSError always succeeds
+         print("Domain: \(nsError.domain), Code: \(nsError.code)")
+         if let underlyingError = nsError.userInfo[NSUnderlyingErrorKey] as? Error {
+             print("Underlying error: \(underlyingError)")
+         }
+         exit(1)
+     }
+ }
+
+ main()
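
As a quick sanity check on the exported file, a short Python sketch (assuming a PLY written by savePLY above, at the hypothetical path output.ply) can verify the ASCII header before the binary payload:

    def read_ply_header(path: str) -> list[str]:
        """Read header lines up to and including end_header."""
        lines: list[str] = []
        with open(path, "rb") as f:
            for raw in f:
                line = raw.decode("ascii").strip()
                lines.append(line)
                if line == "end_header":
                    return lines
        raise ValueError("end_header not found")

    header = read_ply_header("output.ply")
    assert header[0] == "ply"
    assert header[1] == "format binary_little_endian 1.0"

    # "element vertex N" precedes the 14 float properties written per Gaussian
    vertex_line = next(l for l in header if l.startswith("element vertex "))
    print("Gaussians in file:", int(vertex_line.split()[-1]))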