| | """Convert SHARP PyTorch model to Core ML .mlmodel format. |
| | |
| | This script converts the SHARP (Sharp Monocular View Synthesis) model |
| | from PyTorch (.pt) to Core ML (.mlmodel) format for deployment on Apple devices. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import logging |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | import coremltools as ct |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from PIL import Image |
| |
|
| | |
| | from sharp.models import PredictorParams, create_predictor |
| | from sharp.models.predictor import RGBGaussianPredictor |
| | from sharp.utils import io |
| |
|
# Module-level logger; handlers/level are configured by the entry point.
LOGGER = logging.getLogger(__name__)

# Default pretrained SHARP checkpoint, downloaded when no local path is given.
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
| |
|
| | |
| | |
| | |
| |
|
| | |
# Canonical Core ML output names, in the exact order of the tuple returned by
# SharpModelTraceable.forward. Validation code indexes outputs by this order.
OUTPUT_NAMES = [
    "mean_vectors_3d_positions",
    "singular_values_scales",
    "quaternions_rotations",
    "colors_rgb_linear",
    "opacities_alpha_channel",
]
| |
|
| | |
# Human-readable descriptions attached to each Core ML output feature
# (keys must match OUTPUT_NAMES).
OUTPUT_DESCRIPTIONS = {
    "mean_vectors_3d_positions": (
        "3D positions of Gaussian splats in normalized device coordinates (NDC). "
        "Shape: (1, N, 3), where N is the number of Gaussians."
    ),
    "singular_values_scales": (
        "Scale factors for each Gaussian along its principal axes. "
        "Represents size and anisotropy. Shape: (1, N, 3)."
    ),
    "quaternions_rotations": (
        "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
        "Used to orient the ellipsoid. Shape: (1, N, 4)."
    ),
    "colors_rgb_linear": (
        "RGB color values in linear RGB space (not gamma-corrected). "
        "Shape: (1, N, 3), with range [0, 1]."
    ),
    "opacities_alpha_channel": (
        "Opacity value per Gaussian (alpha channel), used for blending. "
        "Shape: (1, N), where values are in [0, 1]."
    ),
}
| |
|
| |
|
@dataclass
class ToleranceConfig:
    """Tolerance configuration for validation.

    Attributes:
        random_tolerances: Per-output max absolute-difference tolerances used
            when validating against a fixed-seed random input image.
        image_tolerances: Per-output max absolute-difference tolerances used
            when validating against a real input image (looser, since real
            images exercise wider dynamic ranges).
        angular_tolerances_random: Quaternion angular-difference tolerances in
            degrees ('mean', 'p99', 'p99_9', 'max') for random-input runs.
        angular_tolerances_image: Quaternion angular-difference tolerances in
            degrees for real-image runs.
    """

    # Fields default to None and are populated in __post_init__ so each
    # instance gets its own dict (avoids the shared-mutable-default pitfall).
    # NOTE: annotations were previously `dict[str, float] = None`, which
    # contradicted the declared type; they are now Optional.
    random_tolerances: dict[str, float] | None = None
    image_tolerances: dict[str, float] | None = None
    angular_tolerances_random: dict[str, float] | None = None
    angular_tolerances_image: dict[str, float] | None = None

    def __post_init__(self) -> None:
        """Fill in defaults for any tolerance dict the caller left as None."""
        if self.random_tolerances is None:
            self.random_tolerances = {
                "mean_vectors_3d_positions": 0.001,
                "singular_values_scales": 0.0001,
                "quaternions_rotations": 2.0,
                "colors_rgb_linear": 0.002,
                "opacities_alpha_channel": 0.005,
            }

        if self.image_tolerances is None:
            self.image_tolerances = {
                "mean_vectors_3d_positions": 3.5,
                "singular_values_scales": 0.035,
                "quaternions_rotations": 5.0,
                "colors_rgb_linear": 0.01,
                "opacities_alpha_channel": 0.05,
            }

        if self.angular_tolerances_random is None:
            self.angular_tolerances_random = {
                "mean": 0.01,
                "p99": 0.1,
                "p99_9": 1.0,
                "max": 5.0,
            }

        if self.angular_tolerances_image is None:
            self.angular_tolerances_image = {
                "mean": 0.2,
                "p99": 2.0,
                "p99_9": 5.0,
                "max": 25.0,
            }
| |
|
| |
|
class SharpModelTraceable(nn.Module):
    """Fully traceable version of SHARP for Core ML conversion.

    This version removes all dynamic control flow and makes the model
    fully traceable with torch.jit.trace.
    """

    def __init__(self, predictor: RGBGaussianPredictor):
        """Initialize the traceable wrapper.

        Args:
            predictor: The SHARP RGBGaussianPredictor model.
        """
        super().__init__()

        # Re-register the predictor's sub-modules directly on this wrapper so
        # torch.jit sees a flat module hierarchy.
        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment

        # Debug-only scalars captured during eager runs; never written while
        # scripting/tracing (see the guards in forward()).
        self.last_global_scale = None
        self.last_monodepth_min = None

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference with traceable forward pass.

        Args:
            image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
            disparity_factor: Disparity factor tensor of shape (1,).

        Returns:
            Tuple of 5 tensors representing 3D Gaussians.
        """
        # Monocular depth backbone: yields disparity plus intermediate
        # features that later stages reuse.
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Broadcast the per-batch scalar factor over (B, C, H, W).
        disparity_factor_expanded = disparity_factor[:, None, None, None]

        # Disparity -> depth. The clamp avoids division blow-ups near zero
        # and keeps the op free of data-dependent control flow.
        disparity_clamped = monodepth_disparity.clamp(min=1e-4, max=1e4)
        monodepth = disparity_factor_expanded / disparity_clamped

        # Align depth. Second argument is None: the scale-map estimator is
        # expected to be disabled for conversion (see convert_to_coreml).
        monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)

        # Debug capture only in eager mode (.item() is not traceable).
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            self.last_monodepth_min = monodepth.flatten().min().item()

        # Initialize per-pixel Gaussian base values from image + depth.
        init_output = self.init_model(image, monodepth)

        # Same eager-only guard for the global scale debug value.
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            if init_output.global_scale is not None:
                self.last_global_scale = init_output.global_scale.item()

        # Feature extraction, conditioned on monodepth encoder features.
        image_features = self.feature_model(
            init_output.feature_input,
            encodings=monodepth_output.output_features
        )

        # Per-Gaussian residuals predicted from the image features.
        delta_values = self.prediction_head(image_features)

        # Compose final Gaussians from base values + residuals.
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )

        quaternions = gaussians.quaternions

        # Normalize quaternions to unit length; the clamp guards sqrt(0).
        quat_norm_sq = torch.sum(quaternions * quaternions, dim=-1, keepdim=True)
        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-12))
        quaternions_normalized = quaternions / quat_norm

        # Canonicalize the q / -q sign ambiguity: force the largest-magnitude
        # component positive. Implemented via argmax + scatter so it traces
        # cleanly (no Python branching on tensor values).
        abs_quat = torch.abs(quaternions_normalized)
        max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)

        # One-hot mask selecting the dominant component per quaternion.
        one_hot = torch.zeros_like(quaternions_normalized)
        one_hot.scatter_(-1, max_idx, 1.0)

        # Signed value of the dominant component.
        max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)

        # Flip quaternions whose dominant component is negative.
        quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()

        return (
            gaussians.mean_vectors,
            gaussians.singular_values,
            quaternions,
            gaussians.colors,
            gaussians.opacities,
        )
| |
|
| |
|
def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
    """Load SHARP model from checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
            If None, downloads the default model.

    Returns:
        The loaded RGBGaussianPredictor model in eval mode.
    """
    # Resolve the state dict either from disk or from the hosted default.
    if checkpoint_path is not None:
        LOGGER.info("Loading checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    else:
        LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
        state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)

    # Build the predictor with default params, load weights, switch to eval.
    predictor = create_predictor(PredictorParams())
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor
| |
|
| |
|
def convert_to_coreml(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
    compute_precision: ct.precision = ct.precision.FLOAT16,
    compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
    minimum_deployment_target: ct.target | None = None,
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML format.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the .mlmodel file.
        input_shape: Input image shape (height, width). Default is (1536, 1536).
        compute_precision: Precision for compute (FLOAT16 or FLOAT32).
        compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
        minimum_deployment_target: Minimum iOS/macOS deployment target.

    Returns:
        The converted Core ML model.
    """
    LOGGER.info("Preparing model for Core ML conversion...")

    # The scale-map estimator contains non-traceable logic; disabling it makes
    # depth_alignment take its traceable path (forward passes None for it).
    predictor.depth_alignment.scale_map_estimator = None

    model_wrapper = SharpModelTraceable(predictor)
    model_wrapper.eval()

    # Pre-warm so any lazy initialization happens before tracing and does not
    # get baked into the trace.
    LOGGER.info("Pre-warming model for better tracing...")
    with torch.no_grad():
        for _ in range(3):
            warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
            warm_disparity = torch.tensor([1.0])
            _ = model_wrapper(warm_image, warm_disparity)

    # Deterministic example inputs for the trace fallback.
    height, width = input_shape
    torch.manual_seed(42)
    example_image = torch.randn(1, 3, height, width)
    example_disparity_factor = torch.tensor([1.0])

    # Prefer scripting (handles control flow); fall back to tracing.
    LOGGER.info("Attempting torch.jit.script for better tracing...")
    try:
        with torch.no_grad():
            traced_model = torch.jit.script(model_wrapper)
        LOGGER.info("torch.jit.script succeeded, using scripted model")
    except Exception as e:
        LOGGER.warning("torch.jit.script failed: %s", e)
        LOGGER.info("Falling back to torch.jit.trace...")
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model_wrapper,
                (example_image, example_disparity_factor),
                strict=False,
                check_trace=False,
            )

    LOGGER.info("Converting traced model to Core ML...")

    inputs = [
        ct.TensorType(
            name="image",
            shape=(1, 3, height, width),
            dtype=np.float32,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    # Reuse the module-level constant so naming stays consistent with the
    # validation helpers (previously this list was duplicated inline).
    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in OUTPUT_NAMES]

    conversion_kwargs: dict[str, Any] = {
        "inputs": inputs,
        "outputs": outputs,
        "convert_to": "mlprogram",
        "compute_precision": compute_precision,
        "compute_units": compute_units,
    }
    if minimum_deployment_target is not None:
        conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target

    mlmodel = ct.convert(
        traced_model,
        **conversion_kwargs,
    )

    # Model card metadata.
    mlmodel.author = "Apple Inc."
    mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
    mlmodel.short_description = (
        "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
    )
    mlmodel.version = "1.0.0"

    # get_spec() returns a *copy* of the model's spec; edits below must be
    # written back by rebuilding the MLModel from the edited spec, otherwise
    # the feature descriptions are silently dropped on save (previous bug).
    spec = mlmodel.get_spec()

    input_descriptions = {
        "image": "RGB image normalized to [0, 1], shape (1, 3, H, W)",
        "disparity_factor": "Focal length / image width ratio, shape (1,)",
    }
    for inp in spec.description.input:
        if inp.name in input_descriptions:
            inp.shortDescription = input_descriptions[inp.name]

    # Ensure output names/descriptions (module-level OUTPUT_DESCRIPTIONS,
    # previously duplicated inline here).
    for i, name in enumerate(OUTPUT_NAMES):
        if i < len(spec.description.output):
            output = spec.description.output[i]
            output.name = name
            output.shortDescription = OUTPUT_DESCRIPTIONS[name]

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    # Rebuild the model from the edited spec so the metadata is persisted.
    mlmodel = ct.models.MLModel(
        spec, weights_dir=mlmodel.weights_dir, compute_units=compute_units
    )

    LOGGER.info("Saving Core ML model to %s", output_path)
    mlmodel.save(str(output_path))

    return mlmodel
| |
|
| |
|
class QuaternionValidator:
    """Validator for quaternion comparisons with configurable tolerances and outlier analysis."""

    # Fallback angular tolerances (degrees) when the caller supplies none.
    DEFAULT_ANGULAR_TOLERANCES = {
        "mean": 0.01,
        "p99": 0.5,
        "p99_9": 2.0,
        "max": 15.0,
    }

    def __init__(
        self,
        angular_tolerances: dict[str, float] | None = None,
        enable_outlier_analysis: bool = True,
        outlier_thresholds: list[float] | None = None,
    ):
        """Initialize validator with tolerances.

        Args:
            angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees.
            enable_outlier_analysis: Whether to perform detailed outlier analysis.
            outlier_thresholds: List of angle thresholds for outlier counting.
        """
        if angular_tolerances:
            self.angular_tolerances = angular_tolerances
        else:
            self.angular_tolerances = self.DEFAULT_ANGULAR_TOLERANCES.copy()
        self.enable_outlier_analysis = enable_outlier_analysis
        if outlier_thresholds:
            self.outlier_thresholds = outlier_thresholds
        else:
            self.outlier_thresholds = [5.0, 10.0, 15.0]

    @staticmethod
    def canonicalize_quaternion(q: np.ndarray) -> np.ndarray:
        """Canonicalize quaternion to ensure consistent representation.

        Resolves the sign ambiguity (q and -q encode the same rotation) by
        flipping any quaternion whose largest-magnitude component is negative.

        Args:
            q: Quaternion array of shape (..., 4)

        Returns:
            Canonicalized quaternion array.
        """
        dominant_idx = np.argmax(np.abs(q), axis=-1, keepdims=True)
        dominant_value = np.take_along_axis(q, dominant_idx, axis=-1)
        return np.where(dominant_value < 0, -q, q)

    @staticmethod
    def compute_angular_differences(
        quats1: np.ndarray, quats2: np.ndarray
    ) -> tuple[np.ndarray, dict[str, float]]:
        """Compute angular differences between two sets of quaternions.

        Args:
            quats1: First set of quaternions shape (N, 4)
            quats2: Second set of quaternions shape (N, 4)

        Returns:
            Tuple of (angular_differences in degrees, statistics dict)
        """
        def to_unit(q: np.ndarray) -> np.ndarray:
            # Normalize safely; the clip keeps zero-norm rows finite.
            norms = np.linalg.norm(q, axis=-1, keepdims=True)
            return q / np.clip(norms, 1e-12, None)

        canon1 = QuaternionValidator.canonicalize_quaternion(to_unit(quats1))
        canon2 = QuaternionValidator.canonicalize_quaternion(to_unit(quats2))

        # |<q1, q2>| = cos(theta / 2), invariant to residual sign flips.
        # Clipping to [0, 1] keeps arccos well-defined under float round-off.
        cosine_half_angle = np.clip(np.abs(np.sum(canon1 * canon2, axis=-1)), 0.0, 1.0)
        angular_diff_deg = np.degrees(2.0 * np.arccos(cosine_half_angle))

        stats = {
            "mean": float(np.mean(angular_diff_deg)),
            "std": float(np.std(angular_diff_deg)),
            "min": float(np.min(angular_diff_deg)),
            "max": float(np.max(angular_diff_deg)),
        }
        for label, pct in (("p50", 50), ("p90", 90), ("p99", 99), ("p99_9", 99.9)):
            stats[label] = float(np.percentile(angular_diff_deg, pct))

        return angular_diff_deg, stats

    def analyze_outliers(
        self, angular_diff_deg: np.ndarray
    ) -> dict[str, dict[str, int | float]]:
        """Analyze outliers in angular differences.

        Args:
            angular_diff_deg: Array of angular differences in degrees.

        Returns:
            Dict with outlier statistics for each threshold.
        """
        if not self.enable_outlier_analysis:
            return {}

        total = len(angular_diff_deg)
        report: dict[str, dict[str, int | float]] = {}
        for limit in self.outlier_thresholds:
            exceeded = int(np.sum(angular_diff_deg > limit))
            share = (exceeded / total) * 100.0 if total > 0 else 0.0
            report[f">{limit}°"] = {"count": exceeded, "percentage": share}
        return report

    def validate(
        self,
        pt_quaternions: np.ndarray,
        coreml_quaternions: np.ndarray,
        image_name: str = "Unknown",
    ) -> dict:
        """Validate Core ML quaternions against PyTorch quaternions.

        Args:
            pt_quaternions: PyTorch quaternion outputs.
            coreml_quaternions: Core ML quaternion outputs.
            image_name: Name of the image being validated.

        Returns:
            Dict with validation results including status, stats, and outliers.
        """
        diffs_deg, stats = self.compute_angular_differences(
            pt_quaternions, coreml_quaternions
        )

        # Collect every tolerance breach; an empty list means success.
        failure_reasons = [
            f"{key} angular {stats[key]:.4f}° > tolerance {limit:.4f}°"
            for key, limit in self.angular_tolerances.items()
            if key in stats and stats[key] > limit
        ]

        return {
            "image": image_name,
            "passed": not failure_reasons,
            "failure_reasons": failure_reasons,
            "stats": stats,
            "outliers": self.analyze_outliers(diffs_deg),
            "num_gaussians": len(diffs_deg),
        }
| |
|
| |
|
def find_coreml_output_key(name: str, coreml_outputs: dict) -> str:
    """Find matching Core ML output key for a given output name.

    Args:
        name: The expected output name
        coreml_outputs: Dictionary of Core ML outputs

    Returns:
        The matching key from coreml_outputs
    """
    # Exact match first.
    if name in coreml_outputs:
        return name

    # Fuzzy match on the leading token (Core ML can mangle long names).
    # base_name is loop-invariant: hoisted out of the loop (was recomputed
    # on every iteration).
    base_name = name.split('_')[0]
    for key in coreml_outputs:
        if base_name in key.lower():
            return key

    # Positional fallback: assume outputs preserve the declared order.
    output_index = OUTPUT_NAMES.index(name) if name in OUTPUT_NAMES else 0
    return list(coreml_outputs.keys())[output_index]
| |
|
| |
|
def run_inference_pair(
    pytorch_model: RGBGaussianPredictor,
    mlmodel: ct.models.MLModel,
    image_tensor: torch.Tensor,
    disparity_factor: float = 1.0,
    log_internals: bool = False,
) -> tuple[list[np.ndarray], dict[str, np.ndarray]]:
    """Run inference on both PyTorch and Core ML models.

    Args:
        pytorch_model: The PyTorch model
        mlmodel: The Core ML model
        image_tensor: Input image tensor
        disparity_factor: Disparity factor value
        log_internals: Whether to log internal values for debugging

    Returns:
        Tuple of (pytorch_outputs, coreml_outputs)
    """
    # PyTorch side: wrap in the traceable module so both paths share code.
    wrapper = SharpModelTraceable(pytorch_model)
    wrapper.eval()

    image_tensor = image_tensor.float()
    disparity_pt = torch.tensor([disparity_factor], dtype=torch.float32)

    with torch.no_grad():
        pt_outputs = wrapper(image_tensor, disparity_pt)

    # Optional debug scalars captured by the wrapper during eager runs.
    if log_internals:
        scale = getattr(wrapper, 'last_global_scale', None)
        if scale is not None:
            LOGGER.info(f"PyTorch global_scale: {scale:.6f}")
        depth_min = getattr(wrapper, 'last_monodepth_min', None)
        if depth_min is not None:
            LOGGER.info(f"PyTorch monodepth_min: {depth_min:.6f}")

    pt_outputs_np = [tensor.numpy() for tensor in pt_outputs]

    # Core ML side: feed identical numpy inputs.
    coreml_outputs = mlmodel.predict({
        "image": image_tensor.numpy(),
        "disparity_factor": np.array([disparity_factor], dtype=np.float32),
    })

    return pt_outputs_np, coreml_outputs
| |
|
| |
|
def compare_outputs(
    pt_outputs: list[np.ndarray],
    coreml_outputs: dict[str, np.ndarray],
    tolerances: dict[str, float],
    quat_validator: QuaternionValidator,
    image_name: str = "Unknown",
) -> list[dict]:
    """Compare PyTorch and Core ML outputs.

    Args:
        pt_outputs: List of PyTorch outputs
        coreml_outputs: Dictionary of Core ML outputs
        tolerances: Tolerance values per output type
        quat_validator: QuaternionValidator instance
        image_name: Name of the image being validated

    Returns:
        List of validation result dictionaries
    """
    results: list[dict] = []

    for idx, output_name in enumerate(OUTPUT_NAMES):
        reference = pt_outputs[idx]
        candidate = coreml_outputs[find_coreml_output_key(output_name, coreml_outputs)]

        entry = {"output": output_name, "passed": True, "failure_reason": ""}

        if output_name == "quaternions_rotations":
            # Quaternions are compared by rotation angle, not element-wise.
            quat_result = quat_validator.validate(reference, candidate, image_name=image_name)
            stats = quat_result["stats"]
            entry["max_diff"] = f"{stats['max']:.6f}"
            entry["mean_diff"] = f"{stats['mean']:.6f}"
            entry["p99_diff"] = f"{stats['p99']:.6f}"
            entry["passed"] = quat_result["passed"]
            reasons = quat_result["failure_reasons"]
            entry["failure_reason"] = "; ".join(reasons) if reasons else ""
        else:
            # Element-wise absolute difference against the per-output limit.
            abs_diff = np.abs(reference - candidate)
            limit = tolerances.get(output_name, 0.01)
            worst = np.max(abs_diff)
            entry["max_diff"] = f"{worst:.6f}"
            entry["mean_diff"] = f"{np.mean(abs_diff):.6f}"
            entry["p99_diff"] = f"{np.percentile(abs_diff, 99):.6f}"
            if worst > limit:
                entry["passed"] = False
                entry["failure_reason"] = f"max diff {worst:.6f} > tolerance {limit:.6f}"

        results.append(entry)

    return results
| |
|
| |
|
def format_validation_table(
    validation_results: list[dict],
    image_name: str,
    include_image_column: bool = False,
) -> str:
    """Format validation results as a markdown table.

    Args:
        validation_results: List of validation result dicts with keys:
            output, max_diff, mean_diff, p99_diff, passed, etc.
        image_name: Name of the image being validated.
        include_image_column: Whether to include the image name as a column.

    Returns:
        Formatted markdown table as a string.
    """
    # Header/divider depend only on whether the image column is present; the
    # row loop below is shared (previously the whole loop was duplicated in
    # both branches).
    if include_image_column:
        header = "| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|-------|--------|----------|-----------|----------|--------|"
    else:
        header = "| Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|--------|----------|-----------|----------|--------|"

    lines = [header, divider]
    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        cells = [output_name, result["max_diff"], result["mean_diff"], result["p99_diff"], status]
        if include_image_column:
            cells.insert(0, image_name)
        lines.append("| " + " | ".join(cells) + " |")

    return "\n".join(lines)
| |
|
| |
|
def validate_coreml_model(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    input_shape: tuple[int, int] = (1536, 1536),
    tolerance: float = 0.01,
    angular_tolerances: dict[str, float] | None = None,
) -> bool:
    """Validate Core ML model outputs against PyTorch model.

    Runs both models on the same fixed-seed random input and compares each of
    the five outputs: element-wise absolute difference for most outputs, and
    angular difference (via QuaternionValidator) for rotations.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        input_shape: Input image shape (height, width).
        tolerance: Fallback max allowed difference for outputs without an
            entry in the per-output tolerance table.
        angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees.

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("Validating Core ML model against PyTorch...")

    height, width = input_shape

    # Fixed seeds so validation is reproducible run-to-run.
    np.random.seed(42)
    torch.manual_seed(42)

    test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
    test_disparity = np.array([1.0], dtype=np.float32)

    test_image_pt = torch.from_numpy(test_image_np)
    test_disparity_pt = torch.from_numpy(test_disparity)

    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)

    coreml_outputs = mlmodel.predict(
        {"image": test_image_np, "disparity_factor": test_disparity}
    )

    LOGGER.info("PyTorch outputs shapes: %s", [o.shape for o in pt_outputs])
    LOGGER.info("Core ML outputs keys: %s", list(coreml_outputs.keys()))

    # Per-output absolute tolerances for random input (keep in sync with
    # ToleranceConfig.random_tolerances).
    tolerances = {
        "mean_vectors_3d_positions": 0.001,
        "singular_values_scales": 0.0001,
        "quaternions_rotations": 2.0,
        "colors_rgb_linear": 0.002,
        "opacities_alpha_channel": 0.005,
    }

    # Angular tolerances in degrees for the quaternion comparison.
    if angular_tolerances is None:
        angular_tolerances = {
            "mean": 0.01,
            "p99": 0.1,
            "p99_9": 1.0,
            "max": 5.0,
        }

    quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances)

    all_passed = True

    # Depth (Z) statistics up front: the depth axis is typically the most
    # sensitive to FP16 precision loss in the converted model.
    LOGGER.info("=== Depth/Position Statistics ===")
    pt_positions = pt_outputs[0].numpy()
    coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
    coreml_positions = coreml_outputs[coreml_key]

    LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")

    z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
    LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
    LOGGER.info("=================================")

    validation_results = []

    # Iterate the module-level OUTPUT_NAMES (previously duplicated inline)
    # and reuse find_coreml_output_key instead of re-implementing it.
    for i, name in enumerate(OUTPUT_NAMES):
        pt_output = pt_outputs[i].numpy()
        coreml_output = coreml_outputs[find_coreml_output_key(name, coreml_outputs)]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Rotations are compared by angle to respect the q / -q ambiguity.
            quat_result = quat_validator.validate(pt_output, coreml_output, image_name="Random")
            result.update({
                "max_diff": f"{quat_result['stats']['max']:.6f}",
                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
                "p99_9_diff": f"{quat_result['stats']['p99_9']:.6f}",
                "max_angular": f"{quat_result['stats']['max']:.4f}",
                "mean_angular": f"{quat_result['stats']['mean']:.4f}",
                "p99_angular": f"{quat_result['stats']['p99']:.4f}",
                "passed": quat_result["passed"],
                "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
                "quat_stats": quat_result["stats"],
                "outliers": quat_result["outliers"],
            })
            if not quat_result["passed"]:
                all_passed = False
        else:
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, tolerance)
            max_diff = float(np.max(diff))  # hoisted: previously computed 3x
            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}"
            })
            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Render a markdown summary table into the log.
    LOGGER.info("\n### Validation Results\n")
    LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | P99.9 Diff | Angular Diff (°) | Status |")
    LOGGER.info("|--------|----------|-----------|----------|------------|------------------|--------|")

    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        if "max_angular" in result:
            angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
            p99_9 = result.get("p99_9_diff", "-")
            LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {p99_9} | {angular_info} | {status} |")
        else:
            LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | - | - | {status} |")
    LOGGER.info("")

    # Detailed outlier breakdown for the quaternion output, if collected.
    for result in validation_results:
        if result.get("outliers"):
            LOGGER.info("### Quaternion Outlier Analysis\n")
            LOGGER.info("| Threshold | Count | Percentage |")
            LOGGER.info("|-----------|-------|------------|")
            for threshold, data in result["outliers"].items():
                LOGGER.info(f"| {threshold} | {data['count']} | {data['percentage']:.4f}% |")
            LOGGER.info("")

    return all_passed
| |
|
| |
|
def load_and_preprocess_image(
    image_path: Path,
    target_size: tuple[int, int] = (1536, 1536),
) -> tuple[torch.Tensor, float, tuple[int, int]]:
    """Load and preprocess an input image for SHARP inference.

    Args:
        image_path: Path to the input image file.
        target_size: Target (height, width) for resizing.

    Returns:
        Tuple of (preprocessed image tensor, focal_length_px, original_size)
        - Preprocessed image tensor of shape (1, 3, H, W) in range [0, 1]
        - Focal length in pixels (from EXIF or default)
        - Original image size (width, height)
    """
    LOGGER.info(f"Loading image from {image_path}")

    image_np, original_size, f_px = io.load_rgb(image_path)
    LOGGER.info(f"Original image size: {original_size}, focal length: {f_px:.2f}px")

    # HWC uint8 -> CHW float in [0, 1].
    image_tensor = torch.from_numpy(image_np).float() / 255.0
    image_tensor = image_tensor.permute(2, 0, 1)
    original_height, original_width = image_np.shape[:2]

    # Resize only when the source dimensions differ from the target.
    target_h, target_w = target_size
    if (original_width, original_height) != (target_w, target_h):
        LOGGER.info(f"Resizing to {target_w}x{target_h}")
        import torch.nn.functional as F
        image_tensor = F.interpolate(
            image_tensor.unsqueeze(0),
            size=(target_h, target_w),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)

    # Add the batch dimension expected by the model.
    image_tensor = image_tensor.unsqueeze(0)

    LOGGER.info(f"Preprocessed image shape: {image_tensor.shape}, range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]")

    return image_tensor, f_px, (original_width, original_height)
| |
|
| |
|
def validate_with_image(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
) -> bool:
    """Validate Core ML model outputs against PyTorch model using a real input image.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape (height, width).

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("=" * 60)
    LOGGER.info("Validating Core ML model against PyTorch with real image")
    LOGGER.info("=" * 60)

    # BUG FIX: load_and_preprocess_image returns (tensor, f_px, original_size).
    # Previously the whole tuple was bound to `test_image`, which broke every
    # downstream tensor operation. Unpack it; only the tensor is used here.
    test_image, _f_px, _original_size = load_and_preprocess_image(image_path, input_shape)
    test_disparity = np.array([1.0], dtype=np.float32)

    # Run the PyTorch reference through the same traceable wrapper used for export.
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))

    LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")

    # Feed identical inputs to the Core ML model.
    test_image_np = test_image.numpy()
    coreml_inputs = {
        "image": test_image_np,
        "disparity_factor": test_disparity,
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)

    LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")

    # Use the module-level constant so output ordering stays consistent
    # with the conversion code.
    output_names = list(OUTPUT_NAMES)

    # Per-output absolute-difference tolerances for real-image inputs.
    tolerances = {
        "mean_vectors_3d_positions": 1.2,
        "singular_values_scales": 0.01,
        "quaternions_rotations": 5.0,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
    }

    # Angular tolerances in degrees for quaternion comparison.
    angular_tolerances = {
        "mean": 0.1,
        "p99": 1.0,
        "max": 15.0,
    }

    all_passed = True

    LOGGER.info("\n=== Input Image Statistics ===")
    LOGGER.info(f"Image path: {image_path}")
    LOGGER.info(f"Image shape: {test_image.shape}")
    LOGGER.info(f"Image range: [{test_image.min():.4f}, {test_image.max():.4f}]")
    LOGGER.info(f"Image mean: {test_image.mean(dim=[1,2,3]).tolist()}")
    LOGGER.info("=" * 30)

    # Compare 3D positions first, since depth errors dominate visual quality.
    pt_positions = pt_outputs[0].numpy()
    coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
    coreml_positions = coreml_outputs[coreml_key]

    LOGGER.info("\n=== Depth/Position Statistics ===")
    LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")

    z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
    LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
    LOGGER.info("=================================\n")

    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        # Resolve the Core ML output key: exact match, then substring match on
        # the first name component, then positional fallback.
        coreml_key = None
        if name in coreml_outputs:
            coreml_key = name
        else:
            for key in coreml_outputs:
                base_name = name.split('_')[0]
                if base_name in key.lower():
                    coreml_key = key
                    break
        if coreml_key is None:
            coreml_key = list(coreml_outputs.keys())[i]

        coreml_output = coreml_outputs[coreml_key]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Quaternions are compared by angular distance: normalize, pick a
            # canonical sign (q and -q encode the same rotation), then measure
            # the rotation angle between the two.
            pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
            pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)

            coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
            coreml_output_normalized = coreml_output / np.clip(coreml_quat_norm, 1e-12, None)

            def canonicalize_quaternion(q):
                # Flip the sign so the largest-magnitude component is positive.
                abs_q = np.abs(q)
                max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
                selector = np.zeros_like(q)
                np.put_along_axis(selector, max_component_idx, 1, axis=-1)
                max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
                return np.where(max_component_sign < 0, -q, q)

            pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
            coreml_output_canonical = canonicalize_quaternion(coreml_output_normalized)

            diff = np.abs(pt_output_canonical - coreml_output_canonical)
            dot_products = np.sum(pt_output_canonical * coreml_output_canonical, axis=-1)
            dot_products_flipped = np.sum(pt_output_canonical * (-coreml_output_canonical), axis=-1)

            # Take the better of the two sign choices per quaternion.
            dot_products = np.where(
                np.abs(dot_products) > np.abs(dot_products_flipped),
                np.abs(dot_products),
                np.abs(dot_products_flipped)
            )
            dot_products = np.clip(dot_products, 0.0, 1.0)
            angular_diff_rad = 2 * np.arccos(dot_products)
            angular_diff_deg = np.degrees(angular_diff_rad)
            max_angular = np.max(angular_diff_deg)
            mean_angular = np.mean(angular_diff_deg)
            p99_angular = np.percentile(angular_diff_deg, 99)

            quat_passed = True
            failure_reasons = []

            if mean_angular > angular_tolerances["mean"]:
                quat_passed = False
                failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
            if p99_angular > angular_tolerances["p99"]:
                quat_passed = False
                failure_reasons.append(f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°")
            if max_angular > angular_tolerances["max"]:
                quat_passed = False
                failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")

            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "max_angular": f"{max_angular:.4f}",
                "mean_angular": f"{mean_angular:.4f}",
                "p99_angular": f"{p99_angular:.4f}",
                "passed": quat_passed,
                "failure_reason": "; ".join(failure_reasons) if failure_reasons else ""
            })
            if not quat_passed:
                all_passed = False
        else:
            # Other outputs use a simple absolute-difference threshold.
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, 0.01)
            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}"
            })
            if np.max(diff) > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Emit a Markdown-style summary table.
    LOGGER.info("\n### Image Validation Results\n")
    LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
    LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")

    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        if "max_angular" in result:
            angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
        else:
            angular_info = "-"
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {angular_info} | {status} |")
    LOGGER.info("")

    return all_passed
| |
|
| |
|
def validate_with_image_set(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_paths: list[Path],
    input_shape: tuple[int, int] = (1536, 1536),
) -> bool:
    """Validate the Core ML model against PyTorch over several input images.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_paths: List of paths to input images for validation.
        input_shape: Expected input image shape (height, width).

    Returns:
        True if all validations pass, False otherwise.
    """
    banner = "=" * 60
    LOGGER.info(banner)
    LOGGER.info(f"Validating Core ML model with {len(image_paths)} images")
    LOGGER.info(banner)

    # One shared validator so every image is judged against the same
    # angular thresholds (degrees).
    shared_validator = QuaternionValidator(
        angular_tolerances={
            "mean": 0.2,
            "p99": 2.0,
            "p99_9": 5.0,
            "max": 25.0,
        }
    )

    overall_ok = True
    collected: list[dict] = []

    for path in image_paths:
        if not path.exists():
            LOGGER.error(f"Input image not found: {path}")
            overall_ok = False
            continue

        LOGGER.info(f"\n--- Validating with {path.name} ---")

        per_image = validate_with_single_image_detailed(
            mlmodel, pytorch_model, path, input_shape, shared_validator
        )

        # Tag each result with its source image before aggregating.
        for entry in per_image:
            entry["image"] = path.name
        collected.extend(per_image)

        overall_ok = overall_ok and all(entry["passed"] for entry in per_image)

    # Aggregate summary across all images.
    LOGGER.info("\n" + banner)
    LOGGER.info("### Multi-Image Validation Summary")
    LOGGER.info(banner + "\n")

    if collected:
        LOGGER.info(format_validation_table(collected, "", include_image_column=True))
        LOGGER.info("")

    return overall_ok
| |
|
| |
|
def validate_with_single_image_detailed(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int],
    quat_validator: QuaternionValidator | None = None,
) -> list[dict]:
    """Run one image through both models and return per-output comparison results.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape.
        quat_validator: Optional QuaternionValidator instance.

    Returns:
        List of validation result dictionaries.
    """
    image_tensor, focal_px, (width_px, _height_px) = load_and_preprocess_image(image_path, input_shape)

    # SHARP convention: disparity factor is focal length over original width.
    disparity = focal_px / width_px
    LOGGER.info(f"Using disparity_factor = {disparity:.6f} (f_px={focal_px:.2f} / width={width_px})")

    torch_out, coreml_out = run_inference_pair(
        pytorch_model, mlmodel, image_tensor,
        disparity_factor=disparity,
        log_internals=True
    )

    # Position (depth) diagnostics.
    positions_pt = torch_out[0]
    positions_ml = coreml_out[find_coreml_output_key("mean_vectors_3d_positions", coreml_out)]

    LOGGER.info(f"=== Depth/Position Statistics ({image_path.name}) ===")
    LOGGER.info(f"PyTorch positions - Z range: [{positions_pt[..., 2].min():.4f}, {positions_pt[..., 2].max():.4f}], mean: {positions_pt[..., 2].mean():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{positions_ml[..., 2].min():.4f}, {positions_ml[..., 2].max():.4f}], mean: {positions_ml[..., 2].mean():.4f}")

    delta_xyz = np.abs(positions_pt - positions_ml)
    LOGGER.info(f"Position difference (X,Y,Z) - max: [{delta_xyz[..., 0].max():.6f}, {delta_xyz[..., 1].max():.6f}, {delta_xyz[..., 2].max():.6f}]")
    LOGGER.info(f"Position difference (X,Y,Z) - mean: [{delta_xyz[..., 0].mean():.6f}, {delta_xyz[..., 1].mean():.6f}, {delta_xyz[..., 2].mean():.6f}]")

    # Relative Z error (clipped denominator avoids division by ~0 depths).
    delta_z = np.abs(positions_pt[..., 2] - positions_ml[..., 2])
    rel_z = delta_z / np.clip(positions_pt[..., 2], 1e-6, None)
    LOGGER.info(f"Z relative error - mean: {rel_z.mean()*100:.4f}%, max: {rel_z.max()*100:.4f}%")

    # Scale diagnostics.
    scales_pt = torch_out[1]
    scales_ml = coreml_out[find_coreml_output_key("singular_values_scales", coreml_out)]
    rel_scales = np.abs(scales_pt - scales_ml) / np.clip(scales_pt, 1e-6, None)
    LOGGER.info(f"Scales relative error - mean: {rel_scales.mean()*100:.4f}%, max: {rel_scales.max()*100:.4f}%")

    config = ToleranceConfig()

    # Fall back to image-specific angular tolerances when no validator is given.
    if quat_validator is None:
        quat_validator = QuaternionValidator(
            angular_tolerances=config.angular_tolerances_image
        )

    return compare_outputs(
        torch_out,
        coreml_out,
        config.image_tolerances,
        quat_validator,
        image_name=image_path.name
    )
| |
|
| |
|
def validate_with_single_image(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int],
    quat_validator: QuaternionValidator | None = None,
) -> bool:
    """Validate with a single image using the new QuaternionValidator.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape.
        quat_validator: Optional QuaternionValidator instance.

    Returns:
        True if validation passes, False otherwise.
    """
    # BUG FIX: load_and_preprocess_image returns (tensor, f_px, original_size).
    # Previously the whole tuple was bound to `test_image`, breaking every
    # downstream tensor operation. Unpack it; only the tensor is used here.
    test_image, _f_px, _original_size = load_and_preprocess_image(image_path, input_shape)
    test_disparity = np.array([1.0], dtype=np.float32)

    # Run the PyTorch reference through the same traceable wrapper used for export.
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))

    # Feed identical inputs to the Core ML model.
    test_image_np = test_image.numpy()
    coreml_inputs = {
        "image": test_image_np,
        "disparity_factor": test_disparity,
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)

    # Use the module-level constant so output ordering stays consistent
    # with the conversion code.
    output_names = list(OUTPUT_NAMES)

    # Per-output absolute-difference tolerances for real-image inputs.
    tolerances = {
        "mean_vectors_3d_positions": 1.2,
        "singular_values_scales": 0.01,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
        "quaternions_rotations": 5.0,
    }

    # Default validator uses the QuaternionValidator's own default tolerances.
    if quat_validator is None:
        quat_validator = QuaternionValidator()

    LOGGER.info(f"Image: {image_path.name}, shape: {test_image.shape}, range: [{test_image.min():.4f}, {test_image.max():.4f}]")

    all_passed = True
    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        # Resolve the Core ML output key: exact match, then substring match on
        # the first name component, then positional fallback.
        coreml_key = None
        if name in coreml_outputs:
            coreml_key = name
        else:
            for key in coreml_outputs:
                base_name = name.split('_')[0]
                if base_name in key.lower():
                    coreml_key = key
                    break
        if coreml_key is None:
            coreml_key = list(coreml_outputs.keys())[i]

        coreml_output = coreml_outputs[coreml_key]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Delegate sign-canonicalization and angular checks to the validator.
            quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_path.name)

            result.update({
                "max_diff": f"{quat_result['stats']['max']:.6f}",
                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
                "passed": quat_result["passed"],
                "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
            })

            if not quat_result["passed"]:
                all_passed = False
        else:
            # Other outputs use a simple absolute-difference threshold.
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, 0.01)
            max_diff = np.max(diff)

            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
            })

            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Emit a summary table for this image.
    LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
    table = format_validation_table(validation_results, image_path.name, include_image_column=False)
    LOGGER.info(table)
    LOGGER.info("")

    return all_passed
| |
|
| |
|
def main():
    """Main conversion script.

    Parses CLI arguments, converts the SHARP checkpoint to Core ML, and
    optionally validates the converted model against PyTorch.

    Returns:
        Process exit code: 0 on success, 1 on validation failure.
    """
    parser = argparse.ArgumentParser(
        description="Convert SHARP PyTorch model to Core ML format"
    )
    parser.add_argument(
        "-c", "--checkpoint",
        type=Path,
        default=None,
        help="Path to PyTorch checkpoint. Downloads default if not provided.",
    )
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=Path("sharp.mlpackage"),
        help="Output path for Core ML model (default: sharp.mlpackage)",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1536,
        help="Input image height (default: 1536)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1536,
        help="Input image width (default: 1536)",
    )
    parser.add_argument(
        "--precision",
        choices=["float16", "float32"],
        default="float32",
        help="Compute precision (default: float32)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate Core ML model against PyTorch",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--input-image",
        type=Path,
        default=None,
        action="append",
        help="Path to input image for validation (can be specified multiple times, requires --validate)",
    )
    parser.add_argument(
        "--tolerance-mean",
        type=float,
        default=None,
        help="Custom mean angular tolerance in degrees (default: 0.01 for random, 0.1 for images)",
    )
    parser.add_argument(
        "--tolerance-p99",
        type=float,
        default=None,
        help="Custom P99 angular tolerance in degrees (default: 0.5 for random, 1.0 for images)",
    )
    parser.add_argument(
        "--tolerance-max",
        type=float,
        default=None,
        help="Custom max angular tolerance in degrees (default: 15.0)",
    )

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    input_shape = (args.height, args.width)
    precision = ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32

    LOGGER.info("Converting using direct tracing...")
    mlmodel = convert_to_coreml(
        predictor,
        args.output,
        input_shape=input_shape,
        compute_precision=precision,
    )

    LOGGER.info(f"Core ML model saved to {args.output}")

    if args.validate:
        if args.input_image:
            # Validate against one or more user-supplied images.
            validation_passed = validate_with_image_set(mlmodel, predictor, args.input_image, input_shape)
        else:
            # Random-input validation, with optional custom angular tolerances.
            # BUG FIX: compare against None rather than relying on truthiness,
            # so an explicit tolerance of 0.0 is honored instead of being
            # silently treated as "not provided".
            custom = (args.tolerance_mean, args.tolerance_p99, args.tolerance_max)
            angular_tolerances = None
            if any(t is not None for t in custom):
                angular_tolerances = {
                    "mean": args.tolerance_mean if args.tolerance_mean is not None else 0.01,
                    "p99": args.tolerance_p99 if args.tolerance_p99 is not None else 0.5,
                    "p99_9": 2.0,
                    "max": args.tolerance_max if args.tolerance_max is not None else 15.0,
                }
            validation_passed = validate_coreml_model(mlmodel, predictor, input_shape, angular_tolerances=angular_tolerances)

        if validation_passed:
            LOGGER.info("✓ Validation passed!")
        else:
            LOGGER.error("✗ Validation failed!")
            return 1

    LOGGER.info("Conversion complete!")
    return 0
| |
|
| |
|
if __name__ == "__main__":
    # BUG FIX: a duplicated `exit(main())` line was unreachable dead code.
    # Run once and propagate the exit status; SystemExit does not depend on
    # the `site`-provided `exit()` helper.
    raise SystemExit(main())
| |
|