File size: 28,910 Bytes
5fb2d50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
"""Convert SHARP PyTorch model to Core ML .mlmodel format.

This script converts the SHARP (Sharp Monocular View Synthesis) model
from PyTorch (.pt) to Core ML (.mlmodel) format for deployment on Apple devices.
"""

from __future__ import annotations

import argparse
import logging
from pathlib import Path
from typing import Any

import coremltools as ct
import numpy as np
import torch
import torch.nn as nn

# Import SHARP model components
from sharp.models import PredictorParams, create_predictor
from sharp.models.predictor import RGBGaussianPredictor

LOGGER = logging.getLogger(__name__)

DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"


class SafeClamp(nn.Module):
    """Clamp a tensor into a bounded positive range.

    Wrapped as an ``nn.Module`` so the operation stays friendly to
    ``torch.jit`` tracing (the original motivation for this helper).
    """

    def forward(self, x, min_val=1e-4, max_val=1e4):
        """Return ``x`` with every element limited to ``[min_val, max_val]``."""
        return x.clamp(min=min_val, max=max_val)


class SafeDivision(nn.Module):
    """Elementwise division guarded against a zero (or near-zero) denominator."""

    def forward(self, numerator, denominator):
        """Return ``numerator / max(denominator, 1e-8)`` elementwise."""
        safe_denominator = denominator.clamp(min=1e-8)
        return numerator / safe_denominator


class SharpModelTraceable(nn.Module):
    """Fully traceable version of SHARP for Core ML conversion.

    This version removes all dynamic control flow and makes the model
    fully traceable with torch.jit.trace.

    The forward pass chains the predictor's submodules in a fixed order
    (monodepth -> depth alignment -> init -> features -> prediction head
    -> gaussian composer) and finishes with an explicit quaternion
    normalization/canonicalization so traced outputs compare
    deterministically against the converted Core ML model.
    """

    def __init__(self, predictor: RGBGaussianPredictor):
        """Initialize the traceable wrapper.

        Args:
            predictor: The SHARP RGBGaussianPredictor model.
        """
        super().__init__()
        # Copy all submodules so this wrapper owns the same weights as
        # the original predictor (no re-initialization happens here).
        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment

        # Replace problematic operations with custom modules.
        # NOTE(review): neither module is referenced in forward() below —
        # they appear to be kept for tracing experiments; confirm before
        # removing, as dropping registered submodules is visible in repr().
        self.safe_clamp = SafeClamp()
        self.safe_div = SafeDivision()

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference with traceable forward pass.

        Args:
            image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
            disparity_factor: Disparity factor tensor of shape (1,).

        Returns:
            Tuple of 5 tensors representing 3D Gaussians, in order:
            (mean_vectors, singular_values, quaternions, colors, opacities).
        """
        # Estimate depth using monodepth
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Convert disparity to depth with higher precision
        # Use tighter clamp bounds and higher precision intermediate computation
        # Broadcast the per-batch factor over (C, H, W).
        disparity_factor_expanded = disparity_factor[:, None, None, None]

        # Cast to float64 for more precise division, then back to float32
        disparity_clamped = monodepth_disparity.clamp(min=1e-6, max=1e4)
        monodepth = disparity_factor_expanded.double() / disparity_clamped.double()
        monodepth = monodepth.float()

        # Apply depth alignment (inference mode)
        # The ``None`` second argument presumably stands for the absent
        # ground-truth depth in the inference path — TODO confirm against
        # sharp.models depth_alignment signature.
        monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)

        # Initialize gaussians from the image and aligned depth.
        init_output = self.init_model(image, monodepth)

        # Extract features, conditioning on the monodepth encoder outputs.
        image_features = self.feature_model(
            init_output.feature_input,
            encodings=monodepth_output.output_features
        )

        # Predict deltas
        delta_values = self.prediction_head(image_features)

        # Compose final gaussians: base values refined by predicted deltas.
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )

        # Normalize quaternions for consistent validation and inference
        # This is critical for CoreML conversion accuracy
        quaternions = gaussians.quaternions

        # Use double precision for quaternion normalization to reduce numerical errors
        quaternions_fp64 = quaternions.double()
        quat_norm_sq = torch.sum(quaternions_fp64 * quaternions_fp64, dim=-1, keepdim=True)
        # Clamp before sqrt so a zero quaternion cannot produce NaN.
        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-16))
        quaternions_normalized = quaternions_fp64 / quat_norm

        # Apply sign canonicalization for consistent representation:
        # q and -q encode the same rotation, so pick the sign that makes the
        # largest-magnitude component positive.
        # Find the component with the largest absolute value
        abs_quat = torch.abs(quaternions_normalized)
        max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)

        # Create one-hot selector for the max component
        # (scatter_ instead of fancy indexing keeps the graph traceable).
        one_hot = torch.zeros_like(quaternions_normalized)
        one_hot.scatter_(-1, max_idx, 1.0)

        # Get the sign of the max component
        max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)

        # Canonicalize: flip if max component is negative
        # This matches the validation logic: np.where(max_component_sign < 0, -q, q)
        quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()

        return (
            gaussians.mean_vectors,
            gaussians.singular_values,
            quaternions,
            gaussians.colors,
            gaussians.opacities,
        )


def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
    """Load SHARP model from checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file. If None, the
            default model is downloaded from ``DEFAULT_MODEL_URL``.

    Returns:
        The loaded RGBGaussianPredictor model in eval mode.
    """
    if checkpoint_path is not None:
        LOGGER.info("Loading checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    else:
        LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
        state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)

    # Instantiate with default parameters, then restore the weights.
    predictor = create_predictor(PredictorParams())
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor


def convert_to_coreml(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
    compute_precision: ct.precision = ct.precision.FLOAT16,
    compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
    minimum_deployment_target: ct.target | None = None,
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML format.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the .mlmodel file.
        input_shape: Input image shape (height, width). Default is (1536, 1536).
        compute_precision: Precision for compute (FLOAT16 or FLOAT32).
        compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
        minimum_deployment_target: Minimum iOS/macOS deployment target.

    Returns:
        The converted Core ML model (with input/output descriptions attached).
    """
    LOGGER.info("Preparing model for Core ML conversion...")

    # Ensure depth alignment is disabled for inference
    predictor.depth_alignment.scale_map_estimator = None

    # Create traceable wrapper
    model_wrapper = SharpModelTraceable(predictor)
    model_wrapper.eval()

    # Pre-warm the model with a few forward passes for better tracing
    LOGGER.info("Pre-warming model for better tracing...")
    with torch.no_grad():
        for _ in range(3):
            warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
            warm_disparity = torch.tensor([1.0])
            _ = model_wrapper(warm_image, warm_disparity)

    # Create deterministic example inputs for tracing (same as validation)
    height, width = input_shape
    torch.manual_seed(42)  # Use same seed as validation for consistency
    example_image = torch.randn(1, 3, height, width)
    example_disparity_factor = torch.tensor([1.0])

    LOGGER.info("Attempting torch.jit.script for better tracing...")
    try:
        with torch.no_grad():
            traced_model = torch.jit.script(model_wrapper)
        LOGGER.info("torch.jit.script succeeded, using scripted model")
    except Exception as e:
        # Lazy %-formatting: the message is only rendered if emitted.
        LOGGER.warning("torch.jit.script failed: %s", e)
        LOGGER.info("Falling back to torch.jit.trace...")
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model_wrapper,
                (example_image, example_disparity_factor),
                strict=False,  # Allow some flexibility for complex models
                check_trace=False,  # Skip trace checking to allow more flexibility
            )

    LOGGER.info("Converting traced model to Core ML...")

    # Define input types for Core ML
    inputs = [
        ct.TensorType(
            name="image",
            shape=(1, 3, height, width),
            dtype=np.float32,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    # Define output names with clear, descriptive labels
    output_names = [
        "mean_vectors_3d_positions",         # 3D positions (NDC space)
        "singular_values_scales",            # Scale parameters (diagonal of covariance)
        "quaternions_rotations",             # Rotation as quaternions
        "colors_rgb_linear",                 # RGB colors in linear color space
        "opacities_alpha_channel",           # Opacity values (alpha)
    ]

    # Naming the outputs at conversion time means no post-hoc renaming
    # of the spec is needed (renaming only the spec description would
    # desync it from the MIL program).
    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in output_names]

    # Set up conversion config
    conversion_kwargs: dict[str, Any] = {
        "inputs": inputs,
        "outputs": outputs,  # Specify output names during conversion
        "convert_to": "mlprogram",  # Use ML Program format for better performance
        "compute_precision": compute_precision,
        "compute_units": compute_units,
    }

    if minimum_deployment_target is not None:
        conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target

    # Convert to Core ML
    mlmodel = ct.convert(
        traced_model,
        **conversion_kwargs,
    )

    # Add metadata
    mlmodel.author = "Apple Inc."
    mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
    mlmodel.short_description = (
        "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
    )
    mlmodel.version = "1.0.0"

    # Input descriptions
    input_descriptions = {
        "image": "RGB image normalized to [0, 1], shape (1, 3, H, W)",
        "disparity_factor": "Focal length / image width ratio, shape (1,)",
    }

    # Output descriptions with clear intent and units
    output_descriptions = {
        "mean_vectors_3d_positions": (
            "3D positions of Gaussian splats in normalized device coordinates (NDC). "
            "Shape: (1, N, 3), where N is the number of Gaussians."
        ),
        "singular_values_scales": (
            "Scale factors for each Gaussian along its principal axes. "
            "Represents size and anisotropy. Shape: (1, N, 3)."
        ),
        "quaternions_rotations": (
            "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
            "Used to orient the ellipsoid. Shape: (1, N, 4)."
        ),
        "colors_rgb_linear": (
            "RGB color values in linear RGB space (not gamma-corrected). "
            "Shape: (1, N, 3), with range [0, 1]."
        ),
        "opacities_alpha_channel": (
            "Opacity value per Gaussian (alpha channel), used for blending. "
            "Shape: (1, N), where values are in [0, 1]."
        ),
    }

    # ``get_spec`` returns a deep copy of the model spec, so edits to it do
    # NOT propagate to ``mlmodel``.  The model must be rebuilt from the
    # edited spec, otherwise the descriptions are silently dropped from the
    # saved artifact.
    spec = mlmodel.get_spec()

    for model_input in spec.description.input:
        if model_input.name in input_descriptions:
            model_input.shortDescription = input_descriptions[model_input.name]

    for model_output in spec.description.output:
        if model_output.name in output_descriptions:
            model_output.shortDescription = output_descriptions[model_output.name]

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    # Rebuild the MLModel around the annotated spec (reusing the existing
    # mlprogram weights directory), then save.
    mlmodel = ct.models.MLModel(spec, weights_dir=mlmodel.weights_dir)

    LOGGER.info("Saving Core ML model to %s", output_path)
    mlmodel.save(str(output_path))

    return mlmodel


def convert_to_coreml_with_preprocessing(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML with built-in image preprocessing.

    This version includes image normalization as part of the model,
    accepting uint8 images as input.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the .mlmodel file.
        input_shape: Input image shape (height, width).

    Returns:
        The converted Core ML model (with output descriptions attached).
    """

    class SharpWithPreprocessing(nn.Module):
        """SHARP model with integrated [0, 255] -> [0, 1] preprocessing."""

        def __init__(self, base_model: SharpModelTraceable):
            super().__init__()
            self.base_model = base_model

        def forward(
            self,
            image: torch.Tensor,
            disparity_factor: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            # Normalize image from [0, 255] to [0, 1]
            image_normalized = image / 255.0
            return self.base_model(image_normalized, disparity_factor)

    model_wrapper = SharpWithPreprocessing(SharpModelTraceable(predictor))
    model_wrapper.eval()

    height, width = input_shape
    # Example input mimics raw 8-bit pixel values stored as float32.
    example_image = torch.randint(0, 256, (1, 3, height, width), dtype=torch.float32)
    example_disparity_factor = torch.tensor([1.0])

    LOGGER.info("Tracing model with preprocessing...")
    with torch.no_grad():
        traced_model = torch.jit.trace(
            model_wrapper,
            (example_image, example_disparity_factor),
            strict=False,
        )

    inputs = [
        ct.ImageType(
            name="image",
            shape=(1, 3, height, width),
            scale=1.0,  # Will be normalized in the model
            color_layout=ct.colorlayout.RGB,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    # Define output names with clear, descriptive labels
    output_names = [
        "mean_vectors_3d_positions",         # 3D positions (NDC space)
        "singular_values_scales",            # Scale parameters (diagonal of covariance)
        "quaternions_rotations",             # Rotation as quaternions
        "colors_rgb_linear",                 # RGB colors in linear color space
        "opacities_alpha_channel",           # Opacity values (alpha)
    ]

    # Name outputs at conversion time so no spec renaming is needed later.
    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in output_names]

    mlmodel = ct.convert(
        traced_model,
        inputs=inputs,
        outputs=outputs,  # Specify output names during conversion
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT16,
    )

    mlmodel.author = "Apple Inc."
    mlmodel.short_description = "SHARP model with integrated image preprocessing"
    mlmodel.version = "1.0.0"

    # Output descriptions with clear intent and units
    output_descriptions = {
        "mean_vectors_3d_positions": (
            "3D positions of Gaussian splats in normalized device coordinates (NDC). "
            "Shape: (1, N, 3), where N is the number of Gaussians."
        ),
        "singular_values_scales": (
            "Scale factors for each Gaussian along its principal axes. "
            "Represents size and anisotropy. Shape: (1, N, 3)."
        ),
        "quaternions_rotations": (
            "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
            "Used to orient the ellipsoid. Shape: (1, N, 4)."
        ),
        "colors_rgb_linear": (
            "RGB color values in linear RGB space (not gamma-corrected). "
            "Shape: (1, N, 3), with range [0, 1]."
        ),
        "opacities_alpha_channel": (
            "Opacity value per Gaussian (alpha channel), used for blending. "
            "Shape: (1, N), where values are in [0, 1]."
        ),
    }

    # ``get_spec`` returns a deep copy: edits to it do NOT reach ``mlmodel``
    # unless the model is rebuilt from the edited spec before saving.
    spec = mlmodel.get_spec()

    for model_output in spec.description.output:
        if model_output.name in output_descriptions:
            model_output.shortDescription = output_descriptions[model_output.name]

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    # Rebuild the model around the annotated spec, then save.
    mlmodel = ct.models.MLModel(spec, weights_dir=mlmodel.weights_dir)
    mlmodel.save(str(output_path))

    return mlmodel


def _canonicalize_quaternion(q: np.ndarray) -> np.ndarray:
    """Flip signs so each quaternion's largest-|value| component is non-negative.

    ``q`` and ``-q`` encode the same rotation; canonicalizing both sides
    makes elementwise comparison between two models meaningful.
    """
    abs_q = np.abs(q)
    max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
    selector = np.zeros_like(q)
    np.put_along_axis(selector, max_component_idx, 1, axis=-1)
    max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
    return np.where(max_component_sign < 0, -q, q)


def _find_coreml_key(name: str, coreml_outputs: dict, index: int) -> str:
    """Map a logical output name to the matching Core ML output key.

    Tries an exact match, then a prefix-based partial match, and finally
    falls back to positional matching.
    """
    if name in coreml_outputs:
        return name
    base_name = name.split("_")[0]  # invariant: computed once, not per key
    for key in coreml_outputs:
        if base_name in key.lower():
            return key
    return list(coreml_outputs.keys())[index]


def _compare_quaternions(
    pt_output: np.ndarray,
    coreml_output: np.ndarray,
    angular_tolerances: dict[str, float],
) -> dict:
    """Compare quaternion outputs via angular error; return table-row fields.

    Both sides are normalized and sign-canonicalized first, then the
    rotation angle between corresponding quaternions is checked against
    mean / p99 / max tolerances (in degrees).
    """
    pt_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
    pt_canonical = _canonicalize_quaternion(pt_output / np.clip(pt_norm, 1e-12, None))

    cm_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
    cm_canonical = _canonicalize_quaternion(coreml_output / np.clip(cm_norm, 1e-12, None))

    diff = np.abs(pt_canonical - cm_canonical)

    # Angle between unit quaternions: theta = 2*arccos(|<q1, q2>|).
    dot_products = np.clip(np.abs(np.sum(pt_canonical * cm_canonical, axis=-1)), 0.0, 1.0)
    angular_diff_deg = np.degrees(2 * np.arccos(dot_products))
    stats = {
        "mean": float(np.mean(angular_diff_deg)),
        "p99": float(np.percentile(angular_diff_deg, 99)),
        "max": float(np.max(angular_diff_deg)),
    }

    failure_reasons = [
        f"{key} angular {stats[key]:.4f}° > {angular_tolerances[key]:.4f}°"
        for key in ("mean", "p99", "max")
        if stats[key] > angular_tolerances[key]
    ]

    return {
        "max_diff": f"{np.max(diff):.6f}",
        "mean_diff": f"{np.mean(diff):.6f}",
        "p99_diff": f"{np.percentile(diff, 99):.6f}",
        "max_angular": f"{stats['max']:.4f}",
        "mean_angular": f"{stats['mean']:.4f}",
        "p99_angular": f"{stats['p99']:.4f}",
        "passed": not failure_reasons,
        "failure_reason": "; ".join(failure_reasons),
    }


def _log_results_table(validation_results: list[dict]) -> None:
    """Emit validation results as a markdown table via the logger."""
    if not validation_results:
        return
    LOGGER.info("\n### Validation Results\n")
    LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
    LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")
    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        if "max_angular" in result:
            angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
        else:
            angular_info = "-"
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        LOGGER.info(
            "| %s | %s | %s | %s | %s | %s |",
            output_name,
            result["max_diff"],
            result["mean_diff"],
            result["p99_diff"],
            angular_info,
            status,
        )
    LOGGER.info("")


def validate_coreml_model(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    input_shape: tuple[int, int] = (1536, 1536),
    tolerance: float = 0.01,
) -> bool:
    """Validate Core ML model outputs against PyTorch model.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        input_shape: Input image shape (height, width).
        tolerance: Fallback max-abs-diff tolerance for outputs without a
            per-output tolerance entry.

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("Validating Core ML model against PyTorch...")

    height, width = input_shape

    # Fixed seeds so both backends see the identical random input.
    np.random.seed(42)
    torch.manual_seed(42)

    test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
    test_disparity = np.array([1.0], dtype=np.float32)

    # Run the same traceable wrapper used for conversion.
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()
    with torch.no_grad():
        pt_outputs = traceable_wrapper(
            torch.from_numpy(test_image_np), torch.from_numpy(test_disparity)
        )

    coreml_outputs = mlmodel.predict(
        {"image": test_image_np, "disparity_factor": test_disparity}
    )

    # Debug: print shapes and keys (lazy formatting — only rendered if emitted).
    LOGGER.info("PyTorch outputs shapes: %s", [o.shape for o in pt_outputs])
    LOGGER.info("Core ML outputs keys: %s", list(coreml_outputs.keys()))

    output_names = [
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel",
    ]

    # Per-output max-abs-diff tolerances (quaternions use angular checks
    # instead, so their entry here is effectively unused).
    tolerances = {
        "mean_vectors_3d_positions": 0.001,
        "singular_values_scales": 0.0001,
        "quaternions_rotations": 2.0,
        "colors_rgb_linear": 0.002,
        "opacities_alpha_channel": 0.005,
    }

    # Angular tolerances for quaternions (in degrees).
    angular_tolerances = {
        "mean": 0.01,
        "p99": 0.5,
        "max": 10.0,
    }

    # Additional diagnostics for depth/position analysis.
    LOGGER.info("=== Depth/Position Statistics ===")
    pt_positions = pt_outputs[0].numpy()
    # Guarded lookup: the old ``[...][0]`` pattern raised IndexError when
    # no key matched.
    position_key = next((k for k in coreml_outputs if "mean_vectors" in k), None)
    if position_key is not None:
        coreml_positions = coreml_outputs[position_key]
        LOGGER.info(
            "PyTorch positions - Z range: [%.4f, %.4f], mean: %.4f, std: %.4f",
            pt_positions[..., 2].min(), pt_positions[..., 2].max(),
            pt_positions[..., 2].mean(), pt_positions[..., 2].std(),
        )
        LOGGER.info(
            "CoreML positions - Z range: [%.4f, %.4f], mean: %.4f, std: %.4f",
            coreml_positions[..., 2].min(), coreml_positions[..., 2].max(),
            coreml_positions[..., 2].mean(), coreml_positions[..., 2].std(),
        )
        z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
        LOGGER.info(
            "Z-coordinate difference - max: %.6f, mean: %.6f, std: %.6f",
            z_diff.max(), z_diff.mean(), z_diff.std(),
        )
    else:
        LOGGER.warning("No 'mean_vectors' output found; skipping depth statistics.")
    LOGGER.info("=================================")

    all_passed = True
    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()
        coreml_output = coreml_outputs[_find_coreml_key(name, coreml_outputs, i)]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Quaternions need special handling: compare rotation angles,
            # not raw components.
            result.update(_compare_quaternions(pt_output, coreml_output, angular_tolerances))
        else:
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, tolerance)
            max_diff = float(np.max(diff))
            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}",
            })
            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = (
                    f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                )

        if not result["passed"]:
            all_passed = False
        validation_results.append(result)

    _log_results_table(validation_results)

    return all_passed


def main():
    """Parse CLI args, convert the SHARP model, and optionally validate.

    Returns:
        Process exit status: 0 on success, 1 if validation fails.
    """
    parser = argparse.ArgumentParser(
        description="Convert SHARP PyTorch model to Core ML format"
    )
    parser.add_argument(
        "-c", "--checkpoint",
        type=Path,
        default=None,
        help="Path to PyTorch checkpoint. Downloads default if not provided.",
    )
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=Path("sharp.mlpackage"),
        help="Output path for Core ML model (default: sharp.mlpackage)",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1536,
        help="Input image height (default: 1536)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1536,
        help="Input image width (default: 1536)",
    )
    parser.add_argument(
        "--precision",
        choices=["float16", "float32"],
        default="float32",
        help="Compute precision (default: float32)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate Core ML model against PyTorch",
    )
    parser.add_argument(
        "--with-preprocessing",
        action="store_true",
        help="Include image preprocessing (uint8 -> float normalization)",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load PyTorch model
    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    # Setup conversion parameters
    input_shape = (args.height, args.width)
    precision = ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32

    # Convert to Core ML.
    # NOTE: the preprocessing path always converts at FLOAT16 and ignores
    # --precision (matches convert_to_coreml_with_preprocessing's interface).
    if args.with_preprocessing:
        LOGGER.info("Converting with integrated preprocessing...")
        mlmodel = convert_to_coreml_with_preprocessing(
            predictor,
            args.output,
            input_shape=input_shape,
        )
    else:
        LOGGER.info("Converting using direct tracing...")
        mlmodel = convert_to_coreml(
            predictor,
            args.output,
            input_shape=input_shape,
            compute_precision=precision,
        )

    # Lazy %-formatting instead of an f-string in the logging call.
    LOGGER.info("Core ML model saved to %s", args.output)

    # Validate if requested
    if args.validate:
        validation_passed = validate_coreml_model(mlmodel, predictor, input_shape)

        if validation_passed:
            LOGGER.info("✓ Validation passed!")
        else:
            LOGGER.error("✗ Validation failed!")
            return 1

    LOGGER.info("Conversion complete!")
    return 0


if __name__ == "__main__":
    # ``raise SystemExit`` instead of the ``exit()`` builtin: ``exit`` is a
    # ``site``-module convenience that is absent under ``python -S``.
    raise SystemExit(main())