| """Convert SHARP PyTorch model to Core ML .mlmodel format. |
| |
| This script converts the SHARP (Sharp Monocular View Synthesis) model |
| from PyTorch (.pt) to Core ML (.mlmodel) format for deployment on Apple devices. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import logging |
| from pathlib import Path |
| from typing import Any |
|
|
| import coremltools as ct |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| |
| from sharp.models import PredictorParams, create_predictor |
| from sharp.models.predictor import RGBGaussianPredictor |
|
|
| LOGGER = logging.getLogger(__name__) |
|
|
| DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt" |
|
|
|
|
class SafeClamp(nn.Module):
    """Clamp a tensor into [min_val, max_val] without trace-breaking control flow."""

    def forward(self, x, min_val=1e-4, max_val=1e4):
        # Tensor.clamp is fully traceable, unlike Python-level min/max branching.
        return x.clamp(min=min_val, max=max_val)
|
|
|
|
class SafeDivision(nn.Module):
    """Elementwise division with the denominator floored at 1e-8 to avoid division by zero."""

    def forward(self, numerator, denominator):
        # Flooring the denominator keeps the op traceable and finite.
        safe_denominator = denominator.clamp(min=1e-8)
        return numerator / safe_denominator
|
|
|
|
class SharpModelTraceable(nn.Module):
    """Fully traceable version of SHARP for Core ML conversion.

    This version removes all dynamic control flow and makes the model
    fully traceable with torch.jit.trace.
    """

    def __init__(self, predictor: RGBGaussianPredictor):
        """Initialize the traceable wrapper.

        Args:
            predictor: The SHARP RGBGaussianPredictor model.
        """
        super().__init__()

        # Re-register the predictor's sub-modules directly on this wrapper so
        # the forward pass below uses exactly these pieces.
        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment

        # NOTE(review): these helpers are registered but never called in
        # `forward` — presumably kept as traceable clamp/divide utilities;
        # confirm before removing.
        self.safe_clamp = SafeClamp()
        self.safe_div = SafeDivision()

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference with traceable forward pass.

        Args:
            image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
            disparity_factor: Disparity factor tensor of shape (1,).

        Returns:
            Tuple of 5 tensors representing 3D Gaussians:
            (mean_vectors, singular_values, quaternions, colors, opacities).
        """
        # Monocular depth backbone: yields a disparity map plus decoder/output
        # features that are reused by the later stages.
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Broadcast the (1,) factor to (1, 1, 1, 1) for elementwise math.
        disparity_factor_expanded = disparity_factor[:, None, None, None]

        # Depth = factor / disparity. The clamp bounds the denominator and the
        # division is done in fp64 to limit rounding error before casting back.
        disparity_clamped = monodepth_disparity.clamp(min=1e-6, max=1e4)
        monodepth = disparity_factor_expanded.double() / disparity_clamped.double()
        monodepth = monodepth.float()

        # Depth alignment; the scale input is None here (the caller disables
        # the scale-map estimator before conversion).
        monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)

        # Initialize Gaussian base values from image + depth.
        init_output = self.init_model(image, monodepth)

        # Image feature extraction, conditioned on the monodepth features.
        image_features = self.feature_model(
            init_output.feature_input,
            encodings=monodepth_output.output_features
        )

        # Per-Gaussian delta predictions from the features.
        delta_values = self.prediction_head(image_features)

        # Compose final Gaussian parameters from base values + deltas.
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )

        quaternions = gaussians.quaternions

        # Normalize the quaternions in fp64 for numerical stability; the clamp
        # guards against a zero-norm quaternion.
        quaternions_fp64 = quaternions.double()
        quat_norm_sq = torch.sum(quaternions_fp64 * quaternions_fp64, dim=-1, keepdim=True)
        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-16))
        quaternions_normalized = quaternions_fp64 / quat_norm

        # Canonicalize the sign: q and -q encode the same rotation, so force
        # the largest-magnitude component to be non-negative. argmax + scatter
        # build a one-hot selector (traceable alternative to fancy indexing).
        abs_quat = torch.abs(quaternions_normalized)
        max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)

        one_hot = torch.zeros_like(quaternions_normalized)
        one_hot.scatter_(-1, max_idx, 1.0)

        # Signed value of the largest-magnitude component (dot with one-hot).
        max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)

        # Flip the whole quaternion where that component is negative; fp32 out.
        quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()

        return (
            gaussians.mean_vectors,
            gaussians.singular_values,
            quaternions,
            gaussians.colors,
            gaussians.opacities,
        )
|
|
|
|
def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
    """Load SHARP model from checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
            If None, downloads the default model.

    Returns:
        The loaded RGBGaussianPredictor model in eval mode.
    """
    # Resolve weights: explicit checkpoint wins, otherwise fetch the default.
    if checkpoint_path is not None:
        LOGGER.info("Loading checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    else:
        LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
        state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)

    # Build the predictor with default params and inject the weights.
    model = create_predictor(PredictorParams())
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
def convert_to_coreml(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
    compute_precision: ct.precision = ct.precision.FLOAT16,
    compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
    minimum_deployment_target: ct.target | None = None,
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML format.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the .mlmodel file.
        input_shape: Input image shape (height, width). Default is (1536, 1536).
        compute_precision: Precision for compute (FLOAT16 or FLOAT32).
        compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
        minimum_deployment_target: Minimum iOS/macOS deployment target.

    Returns:
        The converted Core ML model.
    """
    LOGGER.info("Preparing model for Core ML conversion...")

    # Disable the scale-map estimator: the traceable forward pass supplies
    # None for the depth-alignment scale input, so this sub-module must not run.
    predictor.depth_alignment.scale_map_estimator = None

    model_wrapper = SharpModelTraceable(predictor)
    model_wrapper.eval()

    # Pre-warm: a few full forward passes so any lazily-materialized state
    # exists before torch.jit sees the model.
    LOGGER.info("Pre-warming model for better tracing...")
    with torch.no_grad():
        for _ in range(3):
            warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
            warm_disparity = torch.tensor([1.0])
            _ = model_wrapper(warm_image, warm_disparity)

    height, width = input_shape
    torch.manual_seed(42)  # deterministic example inputs for tracing
    example_image = torch.randn(1, 3, height, width)
    example_disparity_factor = torch.tensor([1.0])

    # Prefer scripting (captures control flow exactly); fall back to tracing.
    LOGGER.info("Attempting torch.jit.script for better tracing...")
    try:
        with torch.no_grad():
            traced_model = torch.jit.script(model_wrapper)
        LOGGER.info("torch.jit.script succeeded, using scripted model")
    except Exception as e:
        LOGGER.warning("torch.jit.script failed: %s", e)
        LOGGER.info("Falling back to torch.jit.trace...")
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model_wrapper,
                (example_image, example_disparity_factor),
                strict=False,
                check_trace=False,
            )

    LOGGER.info("Converting traced model to Core ML...")

    # Declare typed inputs so Core ML fixes shapes/dtypes at conversion time.
    inputs = [
        ct.TensorType(
            name="image",
            shape=(1, 3, height, width),
            dtype=np.float32,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    # Output names assigned by ct.convert via the `outputs` argument below.
    output_names = [
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel",
    ]
    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in output_names]

    conversion_kwargs: dict[str, Any] = {
        "inputs": inputs,
        "outputs": outputs,
        "convert_to": "mlprogram",
        "compute_precision": compute_precision,
        "compute_units": compute_units,
    }
    # Only forward the deployment target when the caller specified one.
    if minimum_deployment_target is not None:
        conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target

    mlmodel = ct.convert(
        traced_model,
        **conversion_kwargs,
    )

    # Model card metadata.
    mlmodel.author = "Apple Inc."
    mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
    mlmodel.short_description = (
        "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
    )
    mlmodel.version = "1.0.0"

    # get_spec() returns a deep copy of the protobuf spec; mutating it does
    # NOT affect `mlmodel`. Edit the copy, then rebuild the model from it
    # below so the descriptions actually persist into the saved package.
    spec = mlmodel.get_spec()

    input_descriptions = {
        "image": "RGB image normalized to [0, 1], shape (1, 3, H, W)",
        "disparity_factor": "Focal length / image width ratio, shape (1,)",
    }

    output_descriptions = {
        "mean_vectors_3d_positions": (
            "3D positions of Gaussian splats in normalized device coordinates (NDC). "
            "Shape: (1, N, 3), where N is the number of Gaussians."
        ),
        "singular_values_scales": (
            "Scale factors for each Gaussian along its principal axes. "
            "Represents size and anisotropy. Shape: (1, N, 3)."
        ),
        "quaternions_rotations": (
            "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
            "Used to orient the ellipsoid. Shape: (1, N, 4)."
        ),
        "colors_rgb_linear": (
            "RGB color values in linear RGB space (not gamma-corrected). "
            "Shape: (1, N, 3), with range [0, 1]."
        ),
        "opacities_alpha_channel": (
            "Opacity value per Gaussian (alpha channel), used for blending. "
            "Shape: (1, N), where values are in [0, 1]."
        ),
    }

    # Attach human-readable descriptions. Outputs are already named via the
    # `outputs` argument to ct.convert; renaming them here would desync the
    # description from the MIL program, so only descriptions are set.
    for inp in spec.description.input:
        if inp.name in input_descriptions:
            inp.shortDescription = input_descriptions[inp.name]
    for output in spec.description.output:
        if output.name in output_descriptions:
            output.shortDescription = output_descriptions[output.name]

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    # Rebuild the MLModel from the edited spec (reusing the existing weights
    # directory) so the metadata edits are part of the saved model.
    mlmodel = ct.models.MLModel(
        spec, weights_dir=mlmodel.weights_dir, compute_units=compute_units
    )

    LOGGER.info("Saving Core ML model to %s", output_path)
    mlmodel.save(str(output_path))

    return mlmodel
|
|
|
|
def convert_to_coreml_with_preprocessing(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML with built-in image preprocessing.

    This version includes image normalization as part of the model,
    accepting 0-255 pixel values as input and normalizing to [0, 1]
    inside the traced graph.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the .mlmodel file.
        input_shape: Input image shape (height, width).

    Returns:
        The converted Core ML model.
    """

    class SharpWithPreprocessing(nn.Module):
        """SHARP model with integrated preprocessing (x / 255 normalization)."""

        def __init__(self, base_model: SharpModelTraceable):
            super().__init__()
            self.base_model = base_model

        def forward(
            self,
            image: torch.Tensor,
            disparity_factor: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            # Map 0-255 pixel values into the [0, 1] range the model expects.
            image_normalized = image / 255.0
            return self.base_model(image_normalized, disparity_factor)

    # Disable the scale-map estimator, matching convert_to_coreml: the
    # traceable forward pass supplies None for the depth-alignment scale input.
    predictor.depth_alignment.scale_map_estimator = None

    model_wrapper = SharpWithPreprocessing(SharpModelTraceable(predictor))
    model_wrapper.eval()

    height, width = input_shape
    # Example input in the raw 0-255 range the wrapper expects.
    example_image = torch.randint(0, 256, (1, 3, height, width), dtype=torch.float32)
    example_disparity_factor = torch.tensor([1.0])

    LOGGER.info("Tracing model with preprocessing...")
    with torch.no_grad():
        traced_model = torch.jit.trace(
            model_wrapper,
            (example_image, example_disparity_factor),
            strict=False,
        )

    # ImageType with scale=1.0 passes raw 0-255 pixel values through; the
    # traced graph performs the /255 normalization itself.
    inputs = [
        ct.ImageType(
            name="image",
            shape=(1, 3, height, width),
            scale=1.0,
            color_layout=ct.colorlayout.RGB,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    # Output names assigned by ct.convert via the `outputs` argument below.
    output_names = [
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel",
    ]
    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in output_names]

    mlmodel = ct.convert(
        traced_model,
        inputs=inputs,
        outputs=outputs,
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT16,
    )

    # Model card metadata.
    mlmodel.author = "Apple Inc."
    mlmodel.short_description = "SHARP model with integrated image preprocessing"
    mlmodel.version = "1.0.0"

    output_descriptions = {
        "mean_vectors_3d_positions": (
            "3D positions of Gaussian splats in normalized device coordinates (NDC). "
            "Shape: (1, N, 3), where N is the number of Gaussians."
        ),
        "singular_values_scales": (
            "Scale factors for each Gaussian along its principal axes. "
            "Represents size and anisotropy. Shape: (1, N, 3)."
        ),
        "quaternions_rotations": (
            "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
            "Used to orient the ellipsoid. Shape: (1, N, 4)."
        ),
        "colors_rgb_linear": (
            "RGB color values in linear RGB space (not gamma-corrected). "
            "Shape: (1, N, 3), with range [0, 1]."
        ),
        "opacities_alpha_channel": (
            "Opacity value per Gaussian (alpha channel), used for blending. "
            "Shape: (1, N), where values are in [0, 1]."
        ),
    }

    # get_spec() returns a deep copy of the protobuf spec; mutating it does
    # NOT affect `mlmodel`. Edit the copy, then rebuild the model from it
    # below so the descriptions actually persist into the saved package.
    spec = mlmodel.get_spec()

    # Outputs are already named via the `outputs` argument to ct.convert;
    # only attach human-readable descriptions here.
    for output in spec.description.output:
        if output.name in output_descriptions:
            output.shortDescription = output_descriptions[output.name]

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    # Rebuild from the edited spec (reusing the weights directory) so the
    # description edits are part of the saved model.
    mlmodel = ct.models.MLModel(spec, weights_dir=mlmodel.weights_dir)

    mlmodel.save(str(output_path))

    return mlmodel
|
|
|
|
def validate_coreml_model(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    input_shape: tuple[int, int] = (1536, 1536),
    tolerance: float = 0.01,
) -> bool:
    """Validate Core ML model outputs against PyTorch model.

    Runs both models on the same deterministic random input and compares
    outputs elementwise. Quaternions are additionally compared by angular
    distance after sign canonicalization, since q and -q represent the
    same rotation.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        input_shape: Input image shape (height, width).
        tolerance: Fallback maximum allowed elementwise difference for
            outputs without a dedicated tolerance entry.

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("Validating Core ML model against PyTorch...")

    height, width = input_shape

    # Deterministic test input.
    np.random.seed(42)
    torch.manual_seed(42)

    test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
    test_disparity = np.array([1.0], dtype=np.float32)

    test_image_pt = torch.from_numpy(test_image_np)
    test_disparity_pt = torch.from_numpy(test_disparity)

    # Reference run through the same traceable wrapper used for conversion.
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)

    coreml_outputs = mlmodel.predict({
        "image": test_image_np,
        "disparity_factor": test_disparity,
    })

    LOGGER.info("PyTorch outputs shapes: %s", [o.shape for o in pt_outputs])
    LOGGER.info("Core ML outputs keys: %s", list(coreml_outputs.keys()))

    output_names = [
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel",
    ]

    # Per-output elementwise tolerances; `tolerance` is the fallback.
    tolerances = {
        "mean_vectors_3d_positions": 0.001,
        "singular_values_scales": 0.0001,
        "quaternions_rotations": 2.0,
        "colors_rgb_linear": 0.002,
        "opacities_alpha_channel": 0.005,
    }

    # Angular tolerances in degrees for the quaternion comparison.
    angular_tolerances = {
        "mean": 0.01,
        "p99": 0.5,
        "max": 10.0,
    }

    all_passed = True

    # Summary statistics for the 3D position output (depth sanity check).
    LOGGER.info("=== Depth/Position Statistics ===")
    pt_positions = pt_outputs[0].numpy()
    position_key = next((k for k in coreml_outputs if "mean_vectors" in k), None)
    if position_key is None:
        # Guard instead of crashing with IndexError on a missing output.
        LOGGER.warning("No 'mean_vectors' output found in Core ML outputs")
    else:
        coreml_positions = coreml_outputs[position_key]
        LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
        LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")
        z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
        LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
    LOGGER.info("=================================")

    def _canonicalize_quaternion(q: np.ndarray) -> np.ndarray:
        """Flip quaternion sign so the largest-magnitude component is >= 0."""
        abs_q = np.abs(q)
        max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
        selector = np.zeros_like(q)
        np.put_along_axis(selector, max_component_idx, 1, axis=-1)
        max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
        return np.where(max_component_sign < 0, -q, q)

    def _find_output_key(name: str, index: int) -> str:
        """Map a canonical output name to the matching Core ML output key."""
        if name in coreml_outputs:
            return name
        base_name = name.split('_')[0]  # invariant: computed once, not per key
        for key in coreml_outputs:
            if base_name in key.lower():
                return key
        # Last resort: positional match.
        return list(coreml_outputs.keys())[index]

    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()
        coreml_output = coreml_outputs[_find_output_key(name, i)]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Normalize then canonicalize both sides before comparing.
            pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
            pt_canonical = _canonicalize_quaternion(
                pt_output / np.clip(pt_quat_norm, 1e-12, None)
            )
            coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
            coreml_canonical = _canonicalize_quaternion(
                coreml_output / np.clip(coreml_quat_norm, 1e-12, None)
            )

            diff = np.abs(pt_canonical - coreml_canonical)

            # Angular distance between unit quaternions: theta = 2*acos(|q1.q2|).
            dot_products = np.sum(pt_canonical * coreml_canonical, axis=-1)
            dot_products = np.clip(np.abs(dot_products), 0.0, 1.0)
            angular_diff_deg = np.degrees(2 * np.arccos(dot_products))
            max_angular = np.max(angular_diff_deg)
            mean_angular = np.mean(angular_diff_deg)
            p99_angular = np.percentile(angular_diff_deg, 99)

            failure_reasons = []
            if mean_angular > angular_tolerances["mean"]:
                failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
            if p99_angular > angular_tolerances["p99"]:
                failure_reasons.append(f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°")
            if max_angular > angular_tolerances["max"]:
                failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")

            quat_passed = not failure_reasons
            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "max_angular": f"{max_angular:.4f}",
                "mean_angular": f"{mean_angular:.4f}",
                "p99_angular": f"{p99_angular:.4f}",
                "passed": quat_passed,
                "failure_reason": "; ".join(failure_reasons),
            })
            if not quat_passed:
                all_passed = False
        else:
            diff = np.abs(pt_output - coreml_output)
            max_diff = np.max(diff)
            output_tolerance = tolerances.get(name, tolerance)
            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}",
            })
            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Render a markdown-style summary table.
    if validation_results:
        LOGGER.info("\n### Validation Results\n")
        LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
        LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")
        for result in validation_results:
            output_name = result["output"].replace("_", " ").title()
            if "max_angular" in result:
                angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
            else:
                angular_info = "-"
            status = "✅ PASS" if result["passed"] else "❌ FAIL"
            LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {angular_info} | {status} |")
        LOGGER.info("")

    return all_passed
|
|
|
|
def main():
    """Main conversion script: parse CLI args, load, convert, optionally validate."""
    parser = argparse.ArgumentParser(
        description="Convert SHARP PyTorch model to Core ML format"
    )
    parser.add_argument("-c", "--checkpoint", type=Path, default=None,
                        help="Path to PyTorch checkpoint. Downloads default if not provided.")
    parser.add_argument("-o", "--output", type=Path, default=Path("sharp.mlpackage"),
                        help="Output path for Core ML model (default: sharp.mlpackage)")
    parser.add_argument("--height", type=int, default=1536,
                        help="Input image height (default: 1536)")
    parser.add_argument("--width", type=int, default=1536,
                        help="Input image width (default: 1536)")
    parser.add_argument("--precision", choices=["float16", "float32"], default="float32",
                        help="Compute precision (default: float32)")
    parser.add_argument("--validate", action="store_true",
                        help="Validate Core ML model against PyTorch")
    parser.add_argument("--with-preprocessing", action="store_true",
                        help="Include image preprocessing (uint8 -> float normalization)")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Enable verbose logging")
    args = parser.parse_args()

    # Logging setup: DEBUG when verbose, INFO otherwise.
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    shape = (args.height, args.width)
    chosen_precision = (
        ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32
    )

    # Two conversion paths: with integrated preprocessing, or direct tracing.
    if args.with_preprocessing:
        LOGGER.info("Converting with integrated preprocessing...")
        mlmodel = convert_to_coreml_with_preprocessing(
            predictor, args.output, input_shape=shape
        )
    else:
        LOGGER.info("Converting using direct tracing...")
        mlmodel = convert_to_coreml(
            predictor, args.output, input_shape=shape, compute_precision=chosen_precision
        )

    LOGGER.info(f"Core ML model saved to {args.output}")

    # Optional parity check against the original PyTorch model.
    if args.validate:
        if validate_coreml_model(mlmodel, predictor, shape):
            LOGGER.info("✓ Validation passed!")
        else:
            LOGGER.error("✗ Validation failed!")
            return 1

    LOGGER.info("Conversion complete!")
    return 0
|
|
|
|
if __name__ == "__main__":
    # raise SystemExit instead of the builtin exit(): exit() is a site-module
    # convenience intended for interactive use and may be absent (python -S).
    raise SystemExit(main())
|
|