"""Convert SHARP PyTorch model to Core ML ML Program (.mlpackage) format.

This script converts the SHARP (Sharp Monocular View Synthesis) model from a
PyTorch checkpoint (.pt) to a Core ML ML Program (.mlpackage) for deployment
on Apple devices.
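
Example usage (the script filename here is illustrative):

    python convert_sharp_to_coreml.py --output sharp.mlpackage --validate
    python convert_sharp_to_coreml.py -c sharp.pt --precision float16 --with-preprocessing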
"""

from __future__ import annotations

import argparse
import logging
from pathlib import Path
from typing import Any

import coremltools as ct
import numpy as np
import torch
import torch.nn as nn

from sharp.models import PredictorParams, create_predictor
from sharp.models.predictor import RGBGaussianPredictor

LOGGER = logging.getLogger(__name__)

DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"

# Output names and descriptions shared by the conversion and validation
# functions below.
OUTPUT_NAMES = [
    "mean_vectors_3d_positions",
    "singular_values_scales",
    "quaternions_rotations",
    "colors_rgb_linear",
    "opacities_alpha_channel",
]

OUTPUT_DESCRIPTIONS = {
    "mean_vectors_3d_positions": (
        "3D positions of Gaussian splats in normalized device coordinates (NDC). "
        "Shape: (1, N, 3), where N is the number of Gaussians."
    ),
    "singular_values_scales": (
        "Scale factors for each Gaussian along its principal axes. "
        "Represents size and anisotropy. Shape: (1, N, 3)."
    ),
    "quaternions_rotations": (
        "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
        "Used to orient the ellipsoid. Shape: (1, N, 4)."
    ),
    "colors_rgb_linear": (
        "RGB color values in linear RGB space (not gamma-corrected). "
        "Shape: (1, N, 3), with range [0, 1]."
    ),
    "opacities_alpha_channel": (
        "Opacity value per Gaussian (alpha channel), used for blending. "
        "Shape: (1, N), where values are in [0, 1]."
    ),
}


class SafeClamp(nn.Module):
    """Safe clamp operation that avoids tracing issues."""

    def forward(self, x, min_val=1e-4, max_val=1e4):
        return torch.clamp(x, min=min_val, max=max_val)


class SafeDivision(nn.Module):
    """Safe division that avoids division by zero."""

    def forward(self, numerator, denominator):
        return numerator / torch.clamp(denominator, min=1e-8)
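

# Note: these helpers express clamping and division as pure tensor ops with no
# data-dependent Python branching, so torch.jit.trace records them faithfully.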


class SharpModelTraceable(nn.Module):
    """Fully traceable version of SHARP for Core ML conversion.

    This version removes all dynamic control flow so that the model can be
    traced with torch.jit.trace.
    """

    def __init__(self, predictor: RGBGaussianPredictor):
        """Initialize the traceable wrapper.

        Args:
            predictor: The SHARP RGBGaussianPredictor model.
        """
        super().__init__()

        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment

        # Numerically safe helper ops, registered as submodules.
        self.safe_clamp = SafeClamp()
        self.safe_div = SafeDivision()

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference with a traceable forward pass.

        Args:
            image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
            disparity_factor: Disparity factor tensor of shape (1,).

        Returns:
            Tuple of (mean_vectors, singular_values, quaternions, colors,
            opacities) tensors describing the predicted 3D Gaussians.
        """
        # Predict monocular disparity from the input image.
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Convert disparity to depth: depth = disparity_factor / disparity.
        # The clamp keeps the division finite; the float64 intermediate
        # preserves precision for small disparities.
        disparity_factor_expanded = disparity_factor[:, None, None, None]
        disparity_clamped = monodepth_disparity.clamp(min=1e-6, max=1e4)
        monodepth = disparity_factor_expanded.double() / disparity_clamped.double()
        monodepth = monodepth.float()

        # Align the depth map using the decoder features (no scale map is given).
        monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)

        # Initialize Gaussian base values from the image and aligned depth.
        init_output = self.init_model(image, monodepth)

        image_features = self.feature_model(
            init_output.feature_input,
            encodings=monodepth_output.output_features,
        )

        # Predict per-Gaussian deltas and compose the final Gaussians.
        delta_values = self.prediction_head(image_features)

        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )

        quaternions = gaussians.quaternions

        # Normalize quaternions in float64 for numerical stability.
        quaternions_fp64 = quaternions.double()
        quat_norm_sq = torch.sum(quaternions_fp64 * quaternions_fp64, dim=-1, keepdim=True)
        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-16))
        quaternions_normalized = quaternions_fp64 / quat_norm

        # Canonicalize the sign: q and -q encode the same rotation, so flip
        # each quaternion so that its largest-magnitude component is positive.
        # argmax + one-hot + sum extracts that component's sign without
        # data-dependent indexing, keeping the graph traceable.
        abs_quat = torch.abs(quaternions_normalized)
        max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)
        one_hot = torch.zeros_like(quaternions_normalized)
        one_hot.scatter_(-1, max_idx, 1.0)
        max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)
        quaternions = torch.where(
            max_component_sign < 0, -quaternions_normalized, quaternions_normalized
        ).float()

        return (
            gaussians.mean_vectors,
            gaussians.singular_values,
            quaternions,
            gaussians.colors,
            gaussians.opacities,
        )
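

# A minimal sketch of how this wrapper is exercised (shapes illustrative;
# convert_to_coreml below adds warm-up and a script/trace fallback):
#
#     wrapper = SharpModelTraceable(predictor).eval()
#     example = (torch.randn(1, 3, 1536, 1536), torch.tensor([1.0]))
#     traced = torch.jit.trace(wrapper, example, strict=False)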


def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
    """Load SHARP model from checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
            If None, downloads the default model.

    Returns:
        The loaded RGBGaussianPredictor model in eval mode.
    """
    if checkpoint_path is None:
        LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
        state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
    else:
        LOGGER.info("Loading checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")

    predictor = create_predictor(PredictorParams())
    predictor.load_state_dict(state_dict)
    predictor.eval()

    return predictor
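

# torch.hub.load_state_dict_from_url caches the download, so repeated runs
# reuse the checkpoint. A local file can be passed instead, e.g. (path
# illustrative):
#
#     predictor = load_sharp_model(Path("checkpoints/sharp.pt"))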


def convert_to_coreml(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
    compute_precision: ct.precision = ct.precision.FLOAT16,
    compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
    minimum_deployment_target: ct.target | None = None,
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML format.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the converted model (.mlpackage).
        input_shape: Input image shape (height, width). Default is (1536, 1536).
        compute_precision: Precision for compute (FLOAT16 or FLOAT32).
        compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
        minimum_deployment_target: Minimum iOS/macOS deployment target.

    Returns:
        The converted Core ML model.
    """
    LOGGER.info("Preparing model for Core ML conversion...")

    # The traceable forward pass supplies None for the scale map, so the
    # scale-map estimator is never exercised; dropping it simplifies export.
    predictor.depth_alignment.scale_map_estimator = None

    model_wrapper = SharpModelTraceable(predictor)
    model_wrapper.eval()

    LOGGER.info("Pre-warming model for better tracing...")
    with torch.no_grad():
        for _ in range(3):
            warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
            warm_disparity = torch.tensor([1.0])
            _ = model_wrapper(warm_image, warm_disparity)

    height, width = input_shape
    torch.manual_seed(42)
    example_image = torch.randn(1, 3, height, width)
    example_disparity_factor = torch.tensor([1.0])

    # Prefer torch.jit.script, which captures control flow exactly; fall back
    # to torch.jit.trace if scripting fails.
    LOGGER.info("Attempting torch.jit.script for better tracing...")
    try:
        with torch.no_grad():
            scripted_model = torch.jit.script(model_wrapper)
        LOGGER.info("torch.jit.script succeeded, using scripted model")
        traced_model = scripted_model
    except Exception as e:
        LOGGER.warning("torch.jit.script failed: %s", e)
        LOGGER.info("Falling back to torch.jit.trace...")
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model_wrapper,
                (example_image, example_disparity_factor),
                strict=False,
                check_trace=False,
            )

    LOGGER.info("Converting traced model to Core ML...")

    inputs = [
        ct.TensorType(
            name="image",
            shape=(1, 3, height, width),
            dtype=np.float32,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in OUTPUT_NAMES]

    conversion_kwargs: dict[str, Any] = {
        "inputs": inputs,
        "outputs": outputs,
        "convert_to": "mlprogram",
        "compute_precision": compute_precision,
        "compute_units": compute_units,
    }

    if minimum_deployment_target is not None:
        conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target

    mlmodel = ct.convert(
        traced_model,
        **conversion_kwargs,
    )

    mlmodel.author = "Apple Inc."
    mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
    mlmodel.short_description = (
        "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
    )
    mlmodel.version = "1.0.0"

    spec = mlmodel.get_spec()

    input_descriptions = {
        "image": "RGB image normalized to [0, 1], shape (1, 3, H, W)",
        "disparity_factor": "Focal length / image width ratio, shape (1,)",
    }

    for model_input in spec.description.input:
        if model_input.name in input_descriptions:
            model_input.shortDescription = input_descriptions[model_input.name]

    # The output names were already assigned via `outputs=` above, so this loop
    # mainly attaches descriptions. get_spec() returns a copy of the spec, so
    # the model is rebuilt from the edited spec below for the changes to stick.
    for i, name in enumerate(OUTPUT_NAMES):
        if i < len(spec.description.output):
            output = spec.description.output[i]
            output.name = name
            output.shortDescription = OUTPUT_DESCRIPTIONS[name]

    mlmodel = ct.models.MLModel(spec, weights_dir=mlmodel.weights_dir)

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    LOGGER.info("Saving Core ML model to %s", output_path)
    mlmodel.save(str(output_path))

    return mlmodel
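

# A minimal sketch of loading and running the converted model with coremltools
# (prediction requires macOS; names match the inputs/outputs defined above):
#
#     mlmodel = ct.models.MLModel("sharp.mlpackage")
#     outputs = mlmodel.predict({
#         "image": np.random.rand(1, 3, 1536, 1536).astype(np.float32),
#         "disparity_factor": np.array([1.0], dtype=np.float32),
#     })
#     positions = outputs["mean_vectors_3d_positions"]  # (1, N, 3)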


def convert_to_coreml_with_preprocessing(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML with built-in image preprocessing.

    This version folds the [0, 255] -> [0, 1] normalization into the model and
    exposes the image input as a Core ML image type, so callers can pass images
    directly instead of pre-normalized float tensors.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the converted model (.mlpackage).
        input_shape: Input image shape (height, width).

    Returns:
        The converted Core ML model.
    """

    class SharpWithPreprocessing(nn.Module):
        """SHARP model with integrated preprocessing."""

        def __init__(self, base_model: SharpModelTraceable):
            super().__init__()
            self.base_model = base_model

        def forward(
            self,
            image: torch.Tensor,
            disparity_factor: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            # Map [0, 255] pixel values to the [0, 1] range the base model expects.
            image_normalized = image / 255.0
            return self.base_model(image_normalized, disparity_factor)

    model_wrapper = SharpWithPreprocessing(SharpModelTraceable(predictor))
    model_wrapper.eval()

    height, width = input_shape
    example_image = torch.randint(0, 256, (1, 3, height, width), dtype=torch.float32)
    example_disparity_factor = torch.tensor([1.0])

    LOGGER.info("Tracing model with preprocessing...")
    with torch.no_grad():
        traced_model = torch.jit.trace(
            model_wrapper,
            (example_image, example_disparity_factor),
            strict=False,
        )

    # scale=1.0 keeps raw [0, 255] pixel values; normalization happens inside
    # the wrapper above.
    inputs = [
        ct.ImageType(
            name="image",
            shape=(1, 3, height, width),
            scale=1.0,
            color_layout=ct.colorlayout.RGB,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in OUTPUT_NAMES]

    mlmodel = ct.convert(
        traced_model,
        inputs=inputs,
        outputs=outputs,
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT16,
    )

    mlmodel.author = "Apple Inc."
    mlmodel.short_description = "SHARP model with integrated image preprocessing"
    mlmodel.version = "1.0.0"

    spec = mlmodel.get_spec()

    # As in convert_to_coreml: attach output descriptions to the spec copy and
    # rebuild the model so the changes persist.
    for i, name in enumerate(OUTPUT_NAMES):
        if i < len(spec.description.output):
            output = spec.description.output[i]
            output.name = name
            output.shortDescription = OUTPUT_DESCRIPTIONS[name]

    mlmodel = ct.models.MLModel(spec, weights_dir=mlmodel.weights_dir)

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    LOGGER.info("Saving Core ML model to %s", output_path)
    mlmodel.save(str(output_path))

    return mlmodel
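

# With an image-type input, coremltools' predict expects a PIL image rather
# than an ndarray (sketch; macOS only, file path illustrative):
#
#     from PIL import Image
#     img = Image.open("photo.png").convert("RGB").resize((1536, 1536))
#     outputs = mlmodel.predict(
#         {"image": img, "disparity_factor": np.array([1.0], dtype=np.float32)}
#     )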


def validate_coreml_model(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    input_shape: tuple[int, int] = (1536, 1536),
    tolerance: float = 0.01,
) -> bool:
    """Validate Core ML model outputs against the PyTorch model.

    Note: this assumes the tensor-input model from convert_to_coreml; the
    image-input variant from convert_to_coreml_with_preprocessing expects a
    PIL image rather than an ndarray.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        input_shape: Input image shape (height, width).
        tolerance: Default maximum allowed elementwise difference between
            outputs; the per-output tolerances defined below take precedence.

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("Validating Core ML model against PyTorch...")

    height, width = input_shape

    # Fixed seeds make the comparison reproducible across runs.
    np.random.seed(42)
    torch.manual_seed(42)

    test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
    test_disparity = np.array([1.0], dtype=np.float32)

    test_image_pt = torch.from_numpy(test_image_np)
    test_disparity_pt = torch.from_numpy(test_disparity)

    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)

    # Core ML prediction requires macOS.
    coreml_inputs = {
        "image": test_image_np,
        "disparity_factor": test_disparity,
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)

    LOGGER.info("PyTorch outputs shapes: %s", [o.shape for o in pt_outputs])
    LOGGER.info("Core ML outputs keys: %s", list(coreml_outputs.keys()))

    output_names = OUTPUT_NAMES

    # Per-output elementwise tolerances, looser than exact equality to absorb
    # FLOAT16 rounding introduced by the conversion. The quaternion entry is
    # effectively unused: quaternions are judged by angular error instead.
    tolerances = {
        "mean_vectors_3d_positions": 0.001,
        "singular_values_scales": 0.0001,
        "quaternions_rotations": 2.0,
        "colors_rgb_linear": 0.002,
        "opacities_alpha_channel": 0.005,
    }

    # Angular tolerances for quaternion rotations, in degrees.
    angular_tolerances = {
        "mean": 0.01,
        "p99": 0.5,
        "max": 10.0,
    }

    all_passed = True

    LOGGER.info("=== Depth/Position Statistics ===")
    pt_positions = pt_outputs[0].numpy()
    coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
    coreml_positions = coreml_outputs[coreml_key]

    LOGGER.info(
        "PyTorch positions - Z range: [%.4f, %.4f], mean: %.4f, std: %.4f",
        pt_positions[..., 2].min(), pt_positions[..., 2].max(),
        pt_positions[..., 2].mean(), pt_positions[..., 2].std(),
    )
    LOGGER.info(
        "CoreML positions - Z range: [%.4f, %.4f], mean: %.4f, std: %.4f",
        coreml_positions[..., 2].min(), coreml_positions[..., 2].max(),
        coreml_positions[..., 2].mean(), coreml_positions[..., 2].std(),
    )

    z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
    LOGGER.info(
        "Z-coordinate difference - max: %.6f, mean: %.6f, std: %.6f",
        z_diff.max(), z_diff.mean(), z_diff.std(),
    )
    LOGGER.info("=================================")

    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        # Core ML may decorate output names, so fall back to substring matching
        # and then to positional order.
        coreml_key = None
        if name in coreml_outputs:
            coreml_key = name
        else:
            base_name = name.split("_")[0]
            for key in coreml_outputs:
                if base_name in key.lower():
                    coreml_key = key
                    break
        if coreml_key is None:
            coreml_key = list(coreml_outputs.keys())[i]

        coreml_output = coreml_outputs[coreml_key]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Compare rotations rather than raw components: normalize both,
            # resolve the sign ambiguity (q and -q encode the same rotation),
            # then measure the angle between them as 2 * arccos(|dot(q1, q2)|).
            pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
            pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)

            coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
            coreml_output_normalized = coreml_output / np.clip(coreml_quat_norm, 1e-12, None)

            def canonicalize_quaternion(q):
                # Flip q so its largest-magnitude component is positive,
                # mirroring the canonicalization in SharpModelTraceable.forward.
                abs_q = np.abs(q)
                max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
                selector = np.zeros_like(q)
                np.put_along_axis(selector, max_component_idx, 1, axis=-1)
                max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
                return np.where(max_component_sign < 0, -q, q)

            pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
            coreml_output_canonical = canonicalize_quaternion(coreml_output_normalized)

            diff = np.abs(pt_output_canonical - coreml_output_canonical)
            dot_products = np.sum(pt_output_canonical * coreml_output_canonical, axis=-1)
            dot_products = np.clip(np.abs(dot_products), 0.0, 1.0)
            angular_diff_rad = 2 * np.arccos(dot_products)
            angular_diff_deg = np.degrees(angular_diff_rad)
            max_angular = np.max(angular_diff_deg)
            mean_angular = np.mean(angular_diff_deg)
            p99_angular = np.percentile(angular_diff_deg, 99)

            quat_passed = True
            failure_reasons = []

            if mean_angular > angular_tolerances["mean"]:
                quat_passed = False
                failure_reasons.append(
                    f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°"
                )
            if p99_angular > angular_tolerances["p99"]:
                quat_passed = False
                failure_reasons.append(
                    f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°"
                )
            if max_angular > angular_tolerances["max"]:
                quat_passed = False
                failure_reasons.append(
                    f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°"
                )

            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "max_angular": f"{max_angular:.4f}",
                "mean_angular": f"{mean_angular:.4f}",
                "p99_angular": f"{p99_angular:.4f}",
                "passed": quat_passed,
                "failure_reason": "; ".join(failure_reasons) if failure_reasons else "",
            })
            if not quat_passed:
                all_passed = False
        else:
            # All other outputs are compared elementwise against a tolerance.
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, tolerance)
            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}",
            })
            if np.max(diff) > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = (
                    f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
                )
                all_passed = False

        validation_results.append(result)

    if validation_results:
        LOGGER.info("\n### Validation Results\n")
        LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular max/mean/p99 (°) | Status |")
        LOGGER.info("|--------|----------|-----------|----------|--------------------------|--------|")

        for result in validation_results:
            output_name = result["output"].replace("_", " ").title()
            if "max_angular" in result:
                angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
            else:
                angular_info = "-"
            status = "✅ PASS" if result["passed"] else "❌ FAIL"
            LOGGER.info(
                "| %s | %s | %s | %s | %s | %s |",
                output_name, result["max_diff"], result["mean_diff"],
                result["p99_diff"], angular_info, status,
            )
        LOGGER.info("")

    return all_passed


def main() -> int:
    """Main conversion script."""
    parser = argparse.ArgumentParser(
        description="Convert SHARP PyTorch model to Core ML format"
    )
    parser.add_argument(
        "-c", "--checkpoint",
        type=Path,
        default=None,
        help="Path to PyTorch checkpoint. Downloads default if not provided.",
    )
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=Path("sharp.mlpackage"),
        help="Output path for Core ML model (default: sharp.mlpackage)",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1536,
        help="Input image height (default: 1536)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1536,
        help="Input image width (default: 1536)",
    )
    parser.add_argument(
        "--precision",
        choices=["float16", "float32"],
        default="float32",
        help="Compute precision (default: float32)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate Core ML model against PyTorch",
    )
    parser.add_argument(
        "--with-preprocessing",
        action="store_true",
        help="Include image preprocessing (uint8 -> float normalization)",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    input_shape = (args.height, args.width)
    precision = ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32

    if args.with_preprocessing:
        LOGGER.info("Converting with integrated preprocessing...")
        mlmodel = convert_to_coreml_with_preprocessing(
            predictor,
            args.output,
            input_shape=input_shape,
        )
    else:
        LOGGER.info("Converting using direct tracing...")
        mlmodel = convert_to_coreml(
            predictor,
            args.output,
            input_shape=input_shape,
            compute_precision=precision,
        )

    LOGGER.info("Core ML model saved to %s", args.output)

    if args.validate:
        # Note: validation feeds tensor inputs, so it targets the model
        # produced by the direct-tracing path (see validate_coreml_model).
        validation_passed = validate_coreml_model(mlmodel, predictor, input_shape)

        if validation_passed:
            LOGGER.info("✓ Validation passed!")
        else:
            LOGGER.error("✗ Validation failed!")
            return 1

    LOGGER.info("Conversion complete!")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())