Kyle Pearson
Fix dtype precision, improve depth scaling tolerance, add debug logging, update manifest weights, enhance preprocessing output.
027bd3d
| """Convert SHARP PyTorch model to Core ML .mlmodel format. | |
| This script converts the SHARP (Sharp Monocular View Synthesis) model | |
| from PyTorch (.pt) to Core ML (.mlmodel) format for deployment on Apple devices. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import coremltools as ct | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| # Import SHARP model components | |
| from sharp.models import PredictorParams, create_predictor | |
| from sharp.models.predictor import RGBGaussianPredictor | |
| from sharp.utils import io | |
LOGGER = logging.getLogger(__name__)
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"

# ============================================================================
# Constants & Configuration
# ============================================================================

# Human-readable descriptions attached to each Core ML output, keyed by output
# name. The insertion order of this mapping is the positional order in which
# outputs are indexed by this script, so it must stay aligned with the tuple
# returned by SharpModelTraceable.forward.
OUTPUT_DESCRIPTIONS = {
    "mean_vectors_3d_positions": (
        "3D positions of Gaussian splats in normalized device coordinates (NDC). "
        "Shape: (1, N, 3), where N is the number of Gaussians."
    ),
    "singular_values_scales": (
        "Scale factors for each Gaussian along its principal axes. "
        "Represents size and anisotropy. Shape: (1, N, 3)."
    ),
    "quaternions_rotations": (
        "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
        "Used to orient the ellipsoid. Shape: (1, N, 4)."
    ),
    "colors_rgb_linear": (
        "RGB color values in linear RGB space (not gamma-corrected). "
        "Shape: (1, N, 3), with range [0, 1]."
    ),
    "opacities_alpha_channel": (
        "Opacity value per Gaussian (alpha channel), used for blending. "
        "Shape: (1, N), where values are in [0, 1]."
    ),
}

# Core ML output names in model output order, derived from the mapping above
# so the two can never drift apart.
OUTPUT_NAMES = list(OUTPUT_DESCRIPTIONS)
| class ToleranceConfig: | |
| """Tolerance configuration for validation.""" | |
| # Tolerances for random validation (tight) | |
| random_tolerances: dict[str, float] = None | |
| # Tolerances for real image validation (more lenient) | |
| image_tolerances: dict[str, float] = None | |
| # Angular tolerances for quaternions (in degrees) | |
| angular_tolerances_random: dict[str, float] = None | |
| angular_tolerances_image: dict[str, float] = None | |
| def __post_init__(self): | |
| if self.random_tolerances is None: | |
| self.random_tolerances = { | |
| "mean_vectors_3d_positions": 0.001, | |
| "singular_values_scales": 0.0001, | |
| "quaternions_rotations": 2.0, | |
| "colors_rgb_linear": 0.002, | |
| "opacities_alpha_channel": 0.005, | |
| } | |
| if self.image_tolerances is None: | |
| self.image_tolerances = { | |
| "mean_vectors_3d_positions": 3.5, # Increased to account for depth scaling with focal length | |
| "singular_values_scales": 0.035, # Increased proportionally (scales are depth-dependent) | |
| "quaternions_rotations": 5.0, | |
| "colors_rgb_linear": 0.01, | |
| "opacities_alpha_channel": 0.05, | |
| } | |
| if self.angular_tolerances_random is None: | |
| self.angular_tolerances_random = { | |
| "mean": 0.01, | |
| "p99": 0.1, | |
| "p99_9": 1.0, | |
| "max": 5.0, | |
| } | |
| if self.angular_tolerances_image is None: | |
| self.angular_tolerances_image = { | |
| "mean": 0.2, | |
| "p99": 2.0, | |
| "p99_9": 5.0, | |
| "max": 25.0, | |
| } | |
class SharpModelTraceable(nn.Module):
    """Trace-friendly wrapper around SHARP's RGBGaussianPredictor for Core ML conversion.

    It registers the predictor's submodules on itself and chains them in a
    single forward pass that takes the disparity factor as a tensor input and
    returns a plain 5-tuple of tensors, which is what torch.jit.trace and
    coremltools expect. The only Python-level branches are debug bookkeeping
    guarded by ``torch.jit.is_scripting()`` / ``torch.jit.is_tracing()``, so
    they do not end up in a traced graph.
    """
    def __init__(self, predictor: RGBGaussianPredictor):
        """Initialize the traceable wrapper.

        Args:
            predictor: The SHARP RGBGaussianPredictor whose submodules are reused.
        """
        super().__init__()
        # Reuse (not copy) the predictor's submodules: these assignments share
        # the same module objects, so weights are shared with ``predictor``.
        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment
        # Debug-only values, populated by forward() only when running eagerly
        # (not while scripting or tracing).
        self.last_global_scale = None
        self.last_monodepth_min = None
    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full SHARP pipeline and return the Gaussian parameters.

        Args:
            image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
            disparity_factor: 1-D tensor (shape (1,) in this script's callers);
                it is broadcast over the disparity map via ``[:, None, None, None]``.

        Returns:
            Tuple ``(mean_vectors, singular_values, quaternions, colors,
            opacities)`` taken from the composed Gaussians. The quaternions
            are normalized to unit length and sign-canonicalized so that their
            largest-magnitude component is non-negative.
        """
        # Estimate depth using monodepth
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity
        # Convert disparity to depth: depth = disparity_factor / disparity.
        # Deliberately no .double() cast: the eager reference stays in the
        # input dtype so it is comparable with the converted Core ML model.
        disparity_factor_expanded = disparity_factor[:, None, None, None]
        # Clamp disparity into [1e-4, 1e4] so the division cannot blow up.
        # NOTE(review): bounds presumably mirror the upstream predictor — confirm.
        disparity_clamped = monodepth_disparity.clamp(min=1e-4, max=1e4)
        monodepth = disparity_factor_expanded / disparity_clamped
        # Apply depth alignment; the second argument is passed as None
        # (presumably "no reference depth" at inference time — verify).
        monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)
        # Debug: record the minimum aligned depth (eager mode only).
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            self.last_monodepth_min = monodepth.flatten().min().item()
        # Initialize gaussians
        init_output = self.init_model(image, monodepth)
        # Debug: record global_scale when the init model provides one (eager mode only).
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            if init_output.global_scale is not None:
                self.last_global_scale = init_output.global_scale.item()
        # Extract features
        image_features = self.feature_model(
            init_output.feature_input,
            encodings=monodepth_output.output_features
        )
        # Predict deltas
        delta_values = self.prediction_head(image_features)
        # Compose final gaussians
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )
        # Normalize and sign-canonicalize quaternions.
        #
        # q and -q encode the same rotation. Picking one representative gives
        # deterministic outputs and lets PyTorch and Core ML results be
        # compared element-wise. Renderers do not need this step, since both
        # signs describe the same orientation.
        quaternions = gaussians.quaternions
        # Normalize to unit length. The squared norm is clamped to 1e-12 so a
        # zero quaternion does not cause a division by zero.
        quat_norm_sq = torch.sum(quaternions * quaternions, dim=-1, keepdim=True)
        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-12))
        quaternions_normalized = quaternions / quat_norm
        # Sign canonicalization: make the component with the largest absolute
        # value non-negative (argmax picks the first index on ties).
        abs_quat = torch.abs(quaternions_normalized)
        max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)
        # One-hot mask selecting the largest-magnitude component
        one_hot = torch.zeros_like(quaternions_normalized)
        one_hot.scatter_(-1, max_idx, 1.0)
        # Signed value of that component (only its sign is used below)
        max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)
        # Flip the whole quaternion when that component is negative. This is
        # the same rule as the NumPy check np.where(max_component_sign < 0, -q, q).
        quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()
        return (
            gaussians.mean_vectors,
            gaussians.singular_values,
            quaternions,
            gaussians.colors,
            gaussians.opacities,
        )
def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
    """Load SHARP weights into a freshly built predictor.

    Args:
        checkpoint_path: Path to a local ``.pt`` state-dict checkpoint. If
            ``None``, the default checkpoint is fetched from
            ``DEFAULT_MODEL_URL`` via ``torch.hub.load_state_dict_from_url``.

    Returns:
        An ``RGBGaussianPredictor`` built from default ``PredictorParams``,
        with the checkpoint weights loaded, in eval mode.
    """
    if checkpoint_path is None:
        LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
        # Map tensors to CPU, as the local-file branch does, so a checkpoint
        # serialized from GPU tensors still loads on CPU-only machines.
        state_dict = torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, map_location="cpu", progress=True
        )
    else:
        LOGGER.info("Loading checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    # Create model with default parameters
    predictor = create_predictor(PredictorParams())
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor
def _trace_sharp_wrapper(
    model_wrapper: SharpModelTraceable,
    height: int,
    width: int,
) -> torch.jit.ScriptModule:
    """Return a TorchScript module for ``model_wrapper``: scripted if possible, else traced."""
    # Run a few eager forward passes first. Inputs are drawn from [0, 1],
    # the image range the wrapper documents.
    LOGGER.info("Pre-warming model for better tracing...")
    with torch.no_grad():
        for _ in range(3):
            warm_image = torch.rand(1, 3, height, width)
            warm_disparity = torch.tensor([1.0])
            _ = model_wrapper(warm_image, warm_disparity)

    # Deterministic example inputs for tracing, also in [0, 1].
    torch.manual_seed(42)
    example_image = torch.rand(1, 3, height, width)
    example_disparity_factor = torch.tensor([1.0])

    LOGGER.info("Attempting torch.jit.script for better tracing...")
    try:
        with torch.no_grad():
            scripted_model = torch.jit.script(model_wrapper)
        LOGGER.info("torch.jit.script succeeded, using scripted model")
        return scripted_model
    except Exception as e:  # scripting is best-effort; any failure falls back to tracing
        LOGGER.warning("torch.jit.script failed: %s", e)

    LOGGER.info("Falling back to torch.jit.trace...")
    with torch.no_grad():
        return torch.jit.trace(
            model_wrapper,
            (example_image, example_disparity_factor),
            strict=False,  # Allow some flexibility for complex models
            check_trace=False,  # Skip trace checking to allow more flexibility
        )


def _annotate_coreml_model(mlmodel: ct.models.MLModel) -> None:
    """Attach author/license/version metadata and feature descriptions to ``mlmodel`` in place."""
    mlmodel.author = "Apple Inc."
    mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
    mlmodel.short_description = (
        "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
    )
    mlmodel.version = "1.0.0"

    input_descriptions = {
        "image": "RGB image normalized to [0, 1], shape (1, 3, H, W)",
        "disparity_factor": "Focal length / image width ratio, shape (1,)",
    }
    # Write through MLModel.input_description / output_description, which edit
    # the model's own spec. Editing the object returned by get_spec() would
    # only change a deep copy, and nothing would reach the saved file.
    for name, text in input_descriptions.items():
        if name in mlmodel.input_description:
            mlmodel.input_description[name] = text
        else:
            LOGGER.warning("Core ML model has no input named %s; description skipped", name)
    for name in OUTPUT_NAMES:
        if name in mlmodel.output_description:
            mlmodel.output_description[name] = OUTPUT_DESCRIPTIONS[name]
        else:
            LOGGER.warning("Core ML model has no output named %s; description skipped", name)

    LOGGER.info("Output names after update: %s", list(mlmodel.output_description))


def convert_to_coreml(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
    compute_precision: ct.precision = ct.precision.FLOAT16,
    compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
    minimum_deployment_target: ct.target | None = None,
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML (ML Program) format and save it.

    Side effect: sets ``predictor.depth_alignment.scale_map_estimator`` to
    ``None`` on the caller's predictor before tracing.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Where to save the model. The model is converted with
            ``convert_to="mlprogram"``, for which coremltools saves an
            ``.mlpackage`` bundle, so pass a path with that extension.
        input_shape: Input image shape (height, width). Default is (1536, 1536).
        compute_precision: Precision for compute (FLOAT16 or FLOAT32).
        compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
        minimum_deployment_target: Minimum iOS/macOS deployment target, or
            ``None`` to use the coremltools default.

    Returns:
        The converted Core ML model (already saved to ``output_path``).
    """
    LOGGER.info("Preparing model for Core ML conversion...")
    # Ensure depth alignment is disabled for inference
    predictor.depth_alignment.scale_map_estimator = None

    model_wrapper = SharpModelTraceable(predictor)
    model_wrapper.eval()

    height, width = input_shape
    traced_model = _trace_sharp_wrapper(model_wrapper, height, width)

    LOGGER.info("Converting traced model to Core ML...")
    inputs = [
        ct.TensorType(name="image", shape=(1, 3, height, width), dtype=np.float32),
        ct.TensorType(name="disparity_factor", shape=(1,), dtype=np.float32),
    ]
    # Name outputs at conversion time; OUTPUT_NAMES follows the positional
    # order of the tuple returned by SharpModelTraceable.forward.
    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in OUTPUT_NAMES]

    conversion_kwargs: dict[str, Any] = {
        "inputs": inputs,
        "outputs": outputs,
        "convert_to": "mlprogram",  # Use ML Program format for better performance
        "compute_precision": compute_precision,
        "compute_units": compute_units,
    }
    if minimum_deployment_target is not None:
        conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target

    mlmodel = ct.convert(traced_model, **conversion_kwargs)

    _annotate_coreml_model(mlmodel)

    LOGGER.info("Saving Core ML model to %s", output_path)
    mlmodel.save(str(output_path))
    return mlmodel
| class QuaternionValidator: | |
| """Validator for quaternion comparisons with configurable tolerances and outlier analysis.""" | |
| DEFAULT_ANGULAR_TOLERANCES = { | |
| "mean": 0.01, | |
| "p99": 0.5, | |
| "p99_9": 2.0, | |
| "max": 15.0, | |
| } | |
| def __init__( | |
| self, | |
| angular_tolerances: dict[str, float] | None = None, | |
| enable_outlier_analysis: bool = True, | |
| outlier_thresholds: list[float] | None = None, | |
| ): | |
| """Initialize validator with tolerances. | |
| Args: | |
| angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees. | |
| enable_outlier_analysis: Whether to perform detailed outlier analysis. | |
| outlier_thresholds: List of angle thresholds for outlier counting. | |
| """ | |
| self.angular_tolerances = angular_tolerances or self.DEFAULT_ANGULAR_TOLERANCES.copy() | |
| self.enable_outlier_analysis = enable_outlier_analysis | |
| self.outlier_thresholds = outlier_thresholds or [5.0, 10.0, 15.0] | |
| def canonicalize_quaternion(q: np.ndarray) -> np.ndarray: | |
| """Canonicalize quaternion to ensure consistent representation. | |
| Ensures the quaternion with the largest absolute component is positive. | |
| This handles the sign ambiguity where q and -q represent the same rotation. | |
| Args: | |
| q: Quaternion array of shape (..., 4) | |
| Returns: | |
| Canonicalized quaternion array. | |
| """ | |
| abs_q = np.abs(q) | |
| max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True) | |
| selector = np.zeros_like(q) | |
| np.put_along_axis(selector, max_component_idx, 1.0, axis=-1) | |
| max_component_sign = np.sum(q * selector, axis=-1, keepdims=True) | |
| return np.where(max_component_sign < 0, -q, q) | |
| def compute_angular_differences( | |
| quats1: np.ndarray, quats2: np.ndarray | |
| ) -> tuple[np.ndarray, dict[str, float]]: | |
| """Compute angular differences between two sets of quaternions. | |
| Args: | |
| quats1: First set of quaternions shape (N, 4) | |
| quats2: Second set of quaternions shape (N, 4) | |
| Returns: | |
| Tuple of (angular_differences in degrees, statistics dict) | |
| """ | |
| # Normalize quaternions | |
| norm1 = np.linalg.norm(quats1, axis=-1, keepdims=True) | |
| norm2 = np.linalg.norm(quats2, axis=-1, keepdims=True) | |
| quats1_norm = quats1 / np.clip(norm1, 1e-12, None) | |
| quats2_norm = quats2 / np.clip(norm2, 1e-12, None) | |
| # Canonicalize both | |
| quats1_canon = QuaternionValidator.canonicalize_quaternion(quats1_norm) | |
| quats2_canon = QuaternionValidator.canonicalize_quaternion(quats2_norm) | |
| # Compute dot products for both q·q and q·(-q) to handle sign ambiguity | |
| dot_products = np.sum(quats1_canon * quats2_canon, axis=-1) | |
| dot_products_flipped = np.sum(quats1_canon * (-quats2_canon), axis=-1) | |
| # Take the maximum absolute dot product (handle sign ambiguity) | |
| dot_products = np.maximum(np.abs(dot_products), np.abs(dot_products_flipped)) | |
| dot_products = np.clip(dot_products, 0.0, 1.0) | |
| # Compute angular differences | |
| angular_diff_rad = 2.0 * np.arccos(dot_products) | |
| angular_diff_deg = np.degrees(angular_diff_rad) | |
| # Compute statistics | |
| stats = { | |
| "mean": float(np.mean(angular_diff_deg)), | |
| "std": float(np.std(angular_diff_deg)), | |
| "min": float(np.min(angular_diff_deg)), | |
| "max": float(np.max(angular_diff_deg)), | |
| "p50": float(np.percentile(angular_diff_deg, 50)), | |
| "p90": float(np.percentile(angular_diff_deg, 90)), | |
| "p99": float(np.percentile(angular_diff_deg, 99)), | |
| "p99_9": float(np.percentile(angular_diff_deg, 99.9)), | |
| } | |
| return angular_diff_deg, stats | |
| def analyze_outliers( | |
| self, angular_diff_deg: np.ndarray | |
| ) -> dict[str, dict[str, int | float]]: | |
| """Analyze outliers in angular differences. | |
| Args: | |
| angular_diff_deg: Array of angular differences in degrees. | |
| Returns: | |
| Dict with outlier statistics for each threshold. | |
| """ | |
| if not self.enable_outlier_analysis: | |
| return {} | |
| outlier_stats = {} | |
| total = len(angular_diff_deg) | |
| for threshold in self.outlier_thresholds: | |
| count = int(np.sum(angular_diff_deg > threshold)) | |
| outlier_stats[f">{threshold}°"] = { | |
| "count": count, | |
| "percentage": (count / total) * 100.0 if total > 0 else 0.0, | |
| } | |
| return outlier_stats | |
| def validate( | |
| self, | |
| pt_quaternions: np.ndarray, | |
| coreml_quaternions: np.ndarray, | |
| image_name: str = "Unknown", | |
| ) -> dict: | |
| """Validate Core ML quaternions against PyTorch quaternions. | |
| Args: | |
| pt_quaternions: PyTorch quaternion outputs. | |
| coreml_quaternions: Core ML quaternion outputs. | |
| image_name: Name of the image being validated. | |
| Returns: | |
| Dict with validation results including status, stats, and outliers. | |
| """ | |
| angular_diff_deg, stats = self.compute_angular_differences( | |
| pt_quaternions, coreml_quaternions | |
| ) | |
| outlier_stats = self.analyze_outliers(angular_diff_deg) | |
| # Check tolerances | |
| passed = True | |
| failure_reasons = [] | |
| for key, tolerance in self.angular_tolerances.items(): | |
| if key in stats and stats[key] > tolerance: | |
| passed = False | |
| failure_reasons.append( | |
| f"{key} angular {stats[key]:.4f}° > tolerance {tolerance:.4f}°" | |
| ) | |
| return { | |
| "image": image_name, | |
| "passed": passed, | |
| "failure_reasons": failure_reasons, | |
| "stats": stats, | |
| "outliers": outlier_stats, | |
| "num_gaussians": len(angular_diff_deg), | |
| } | |
def find_coreml_output_key(name: str, coreml_outputs: dict) -> str:
    """Find the Core ML output key that corresponds to ``name``.

    Lookup order: exact key; then the first key containing the leading
    ``_``-separated token of ``name`` (case-insensitive, e.g. "colors" for
    "colors_rgb_linear"); then the key at ``name``'s position in OUTPUT_NAMES.

    Args:
        name: The expected output name.
        coreml_outputs: Dictionary of Core ML outputs.

    Returns:
        The matching key from ``coreml_outputs``.

    Raises:
        KeyError: If no key matches and ``name`` is not a known output name,
            or its position is beyond the number of Core ML outputs.
    """
    if name in coreml_outputs:
        return name

    # Partial match on the leading token; computed once, outside the loop.
    base_name = name.split("_")[0].lower()
    for key in coreml_outputs:
        if base_name in key.lower():
            return key

    # Positional fallback. An unknown name used to map silently to output 0.
    if name not in OUTPUT_NAMES:
        raise KeyError(f"Unknown output name {name!r}; expected one of {OUTPUT_NAMES}")
    output_index = OUTPUT_NAMES.index(name)
    keys = list(coreml_outputs)
    if output_index >= len(keys):
        raise KeyError(
            f"No Core ML output matches {name!r} and only {len(keys)} outputs are available"
        )
    return keys[output_index]
def run_inference_pair(
    pytorch_model: RGBGaussianPredictor,
    mlmodel: ct.models.MLModel,
    image_tensor: torch.Tensor,
    disparity_factor: float = 1.0,
    log_internals: bool = False,
) -> tuple[list[np.ndarray], dict[str, np.ndarray]]:
    """Run the same input through the PyTorch wrapper and the Core ML model.

    Args:
        pytorch_model: The PyTorch predictor (wrapped in SharpModelTraceable here).
        mlmodel: The Core ML model.
        image_tensor: Input image tensor, cast to float32 before use.
        disparity_factor: Disparity factor value, fed to both models.
        log_internals: Whether to log the wrapper's recorded global_scale and
            monodepth minimum for debugging.

    Returns:
        Tuple of (PyTorch outputs as NumPy arrays in forward() order,
        Core ML prediction dict).
    """
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    # float32 for both models; detach so .numpy() also works if the caller's
    # tensor requires grad.
    image_tensor = image_tensor.detach().float()
    test_disparity_pt = torch.tensor([disparity_factor], dtype=torch.float32)

    with torch.no_grad():
        pt_outputs = traceable_wrapper(image_tensor, test_disparity_pt)

    if log_internals:
        # Both attributes always exist (set in __init__); they stay None unless
        # an eager forward pass recorded them.
        if traceable_wrapper.last_global_scale is not None:
            LOGGER.info("PyTorch global_scale: %.6f", traceable_wrapper.last_global_scale)
        if traceable_wrapper.last_monodepth_min is not None:
            LOGGER.info("PyTorch monodepth_min: %.6f", traceable_wrapper.last_monodepth_min)

    pt_outputs_np = [o.numpy() for o in pt_outputs]

    coreml_inputs = {
        "image": image_tensor.numpy(),
        "disparity_factor": np.array([disparity_factor], dtype=np.float32),
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)
    return pt_outputs_np, coreml_outputs
def compare_outputs(
    pt_outputs: list[np.ndarray],
    coreml_outputs: dict[str, np.ndarray],
    tolerances: dict[str, float],
    quat_validator: QuaternionValidator,
    image_name: str = "Unknown",
) -> list[dict]:
    """Compare PyTorch and Core ML outputs, one result per entry of OUTPUT_NAMES.

    Args:
        pt_outputs: PyTorch outputs, in OUTPUT_NAMES order.
        coreml_outputs: Dictionary of Core ML outputs.
        tolerances: Max-abs-diff tolerance per output name (default 0.01).
        quat_validator: Validator used for the quaternion output.
        image_name: Name of the image being validated.

    Returns:
        List of result dicts with keys output, passed, failure_reason,
        max_diff, mean_diff, p99_diff (diffs formatted as strings, or "n/a"
        when the shapes cannot be compared).
    """
    validation_results = []

    for i, name in enumerate(OUTPUT_NAMES):
        pt_output = np.asarray(pt_outputs[i])
        coreml_key = find_coreml_output_key(name, coreml_outputs)
        coreml_output = np.asarray(coreml_outputs[coreml_key])

        result = {"output": name, "passed": True, "failure_reason": ""}

        # Never let NumPy broadcast mismatched shapes: e.g. (1, N) - (1, N, 1)
        # would silently produce an (1, N, N) "diff".
        if coreml_output.shape != pt_output.shape:
            if coreml_output.size == pt_output.size:
                coreml_output = coreml_output.reshape(pt_output.shape)
            else:
                result.update({
                    "max_diff": "n/a",
                    "mean_diff": "n/a",
                    "p99_diff": "n/a",
                    "passed": False,
                    "failure_reason": (
                        f"shape mismatch: PyTorch {pt_output.shape} vs Core ML {coreml_output.shape}"
                    ),
                })
                validation_results.append(result)
                continue

        if name == "quaternions_rotations":
            # Quaternions are compared by rotation angle (sign-invariant).
            quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_name)
            result.update({
                "max_diff": f"{quat_result['stats']['max']:.6f}",
                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
                "passed": quat_result["passed"],
                "failure_reason": "; ".join(quat_result["failure_reasons"]),
            })
        else:
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, 0.01)
            max_diff = np.max(diff)
            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
            })
            # NaN/inf compare False against any threshold, so check explicitly;
            # otherwise a model emitting NaNs (e.g. FP16 overflow) would "pass".
            if not np.isfinite(max_diff):
                result["passed"] = False
                result["failure_reason"] = f"non-finite max diff {max_diff}"
            elif max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"

        validation_results.append(result)

    return validation_results
def format_validation_table(
    validation_results: list[dict],
    image_name: str,
    include_image_column: bool = False,
) -> str:
    """Render validation results as a markdown table.

    Args:
        validation_results: Result dicts providing at least ``output``,
            ``max_diff``, ``mean_diff``, ``p99_diff`` and ``passed``.
        image_name: Image label; only emitted when ``include_image_column`` is set.
        include_image_column: Prepend an "Image" column holding ``image_name``.

    Returns:
        The table (header, separator, one row per result) joined with newlines.
    """
    if include_image_column:
        header = "| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|-------|--------|----------|-----------|----------|--------|"
    else:
        header = "| Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|--------|----------|-----------|----------|--------|"

    def _row(entry: dict) -> str:
        # Pretty column label, e.g. "colors_rgb_linear" -> "Colors Rgb Linear".
        cells = [
            entry["output"].replace("_", " ").title(),
            entry["max_diff"],
            entry["mean_diff"],
            entry["p99_diff"],
            "✅ PASS" if entry["passed"] else "❌ FAIL",
        ]
        if include_image_column:
            cells.insert(0, image_name)
        return "| " + " | ".join(f"{cell}" for cell in cells) + " |"

    return "\n".join([header, divider, *map(_row, validation_results)])
| def validate_coreml_model( | |
| mlmodel: ct.models.MLModel, | |
| pytorch_model: RGBGaussianPredictor, | |
| input_shape: tuple[int, int] = (1536, 1536), | |
| tolerance: float = 0.01, | |
| angular_tolerances: dict[str, float] | None = None, | |
| ) -> bool: | |
| """Validate Core ML model outputs against PyTorch model. | |
| Args: | |
| mlmodel: The Core ML model to validate. | |
| pytorch_model: The original PyTorch model. | |
| input_shape: Input image shape (height, width). | |
| tolerance: Maximum allowed difference between outputs. | |
| angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees. | |
| Returns: | |
| True if validation passes, False otherwise. | |
| """ | |
| LOGGER.info("Validating Core ML model against PyTorch...") | |
| height, width = input_shape | |
| # Set seeds for reproducibility | |
| np.random.seed(42) | |
| torch.manual_seed(42) | |
| # Create test input | |
| test_image_np = np.random.rand(1, 3, height, width).astype(np.float32) | |
| test_disparity = np.array([1.0], dtype=np.float32) | |
| # Run PyTorch model | |
| test_image_pt = torch.from_numpy(test_image_np) | |
| test_disparity_pt = torch.from_numpy(test_disparity) | |
| traceable_wrapper = SharpModelTraceable(pytorch_model) | |
| traceable_wrapper.eval() | |
| with torch.no_grad(): | |
| pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt) | |
| # Run Core ML model | |
| coreml_inputs = { | |
| "image": test_image_np, | |
| "disparity_factor": test_disparity, | |
| } | |
| coreml_outputs = mlmodel.predict(coreml_inputs) | |
| LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}") | |
| LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}") | |
| # Output configuration | |
| output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"] | |
| # Define tolerances per output type | |
| tolerances = { | |
| "mean_vectors_3d_positions": 0.001, | |
| "singular_values_scales": 0.0001, | |
| "quaternions_rotations": 2.0, | |
| "colors_rgb_linear": 0.002, | |
| "opacities_alpha_channel": 0.005, | |
| } | |
| # Use provided angular tolerances or defaults | |
| if angular_tolerances is None: | |
| angular_tolerances = { | |
| "mean": 0.01, | |
| "p99": 0.1, | |
| "p99_9": 1.0, | |
| "max": 5.0, | |
| } | |
| # Initialize quaternion validator | |
| quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances) | |
| all_passed = True | |
| # Additional diagnostics for depth/position analysis | |
| LOGGER.info("=== Depth/Position Statistics ===") | |
| pt_positions = pt_outputs[0].numpy() | |
| coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0] | |
| coreml_positions = coreml_outputs[coreml_key] | |
| LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}") | |
| LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}") | |
| z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2]) | |
| LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}") | |
| LOGGER.info("=================================") | |
| # Collect validation results | |
| validation_results = [] | |
| for i, name in enumerate(output_names): | |
| pt_output = pt_outputs[i].numpy() | |
| # Find matching Core ML output | |
| coreml_key = None | |
| if name in coreml_outputs: | |
| coreml_key = name | |
| else: | |
| # Try partial match | |
| for key in coreml_outputs: | |
| base_name = name.split('_')[0] | |
| if base_name in key.lower(): | |
| coreml_key = key | |
| break | |
| if coreml_key is None: | |
| coreml_key = list(coreml_outputs.keys())[i] | |
| coreml_output = coreml_outputs[coreml_key] | |
| result = {"output": name, "passed": True, "failure_reason": ""} | |
| # Special handling for quaternions | |
| if name == "quaternions_rotations": | |
| # Use the new QuaternionValidator | |
| quat_result = quat_validator.validate(pt_output, coreml_output, image_name="Random") | |
| result.update({ | |
| "max_diff": f"{quat_result['stats']['max']:.6f}", | |
| "mean_diff": f"{quat_result['stats']['mean']:.6f}", | |
| "p99_diff": f"{quat_result['stats']['p99']:.6f}", | |
| "p99_9_diff": f"{quat_result['stats']['p99_9']:.6f}", | |
| "max_angular": f"{quat_result['stats']['max']:.4f}", | |
| "mean_angular": f"{quat_result['stats']['mean']:.4f}", | |
| "p99_angular": f"{quat_result['stats']['p99']:.4f}", | |
| "passed": quat_result["passed"], | |
| "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "", | |
| "quat_stats": quat_result["stats"], | |
| "outliers": quat_result["outliers"], | |
| }) | |
| if not quat_result["passed"]: | |
| all_passed = False | |
| else: | |
| diff = np.abs(pt_output - coreml_output) | |
| output_tolerance = tolerances.get(name, tolerance) | |
| result.update({ | |
| "max_diff": f"{np.max(diff):.6f}", | |
| "mean_diff": f"{np.mean(diff):.6f}", | |
| "p99_diff": f"{np.percentile(diff, 99):.6f}", | |
| "tolerance": f"{output_tolerance:.6f}" | |
| }) | |
| if np.max(diff) > output_tolerance: | |
| result["passed"] = False | |
| result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}" | |
| all_passed = False | |
| validation_results.append(result) | |
| # Output validation results as markdown table | |
| LOGGER.info("\n### Validation Results\n") | |
| LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | P99.9 Diff | Angular Diff (°) | Status |") | |
| LOGGER.info("|--------|----------|-----------|----------|------------|------------------|--------|") | |
| for result in validation_results: | |
| output_name = result["output"].replace("_", " ").title() | |
| if "max_angular" in result: | |
| angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}" | |
| p99_9 = result.get("p99_9_diff", "-") | |
| status = "✅ PASS" if result["passed"] else f"❌ FAIL" | |
| LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {p99_9} | {angular_info} | {status} |") | |
| else: | |
| status = "✅ PASS" if result["passed"] else f"❌ FAIL" | |
| LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | - | - | {status} |") | |
| LOGGER.info("") | |
| # Output quaternion outlier analysis if available | |
| for result in validation_results: | |
| if "outliers" in result and result["outliers"]: | |
| LOGGER.info("### Quaternion Outlier Analysis\n") | |
| LOGGER.info(f"| Threshold | Count | Percentage |") | |
| LOGGER.info("|-----------|-------|------------|") | |
| for threshold, data in result["outliers"].items(): | |
| LOGGER.info(f"| {threshold} | {data['count']} | {data['percentage']:.4f}% |") | |
| LOGGER.info("") | |
| return all_passed | |
def load_and_preprocess_image(
    image_path: Path,
    target_size: tuple[int, int] = (1536, 1536),
) -> tuple[torch.Tensor, float, tuple[int, int]]:
    """Load and preprocess an input image for SHARP inference.

    The image is loaded with ``sharp.utils.io.load_rgb``. It is then scaled
    from uint8 to float32 in [0, 1] and converted from HWC to CHW. When its
    size differs from ``target_size``, it is bilinearly resized with
    ``align_corners=True``. Finally a leading batch dimension is added.

    Args:
        image_path: Path to the input image file.
        target_size: Target (height, width) for resizing.

    Returns:
        Tuple of (preprocessed image tensor, focal_length_px, original_size):
        - C-contiguous float32 tensor of shape (1, C, H, W) with values in
          [0, 1], where (H, W) == ``target_size``
        - Focal length in pixels, as returned by ``io.load_rgb``
        - Original image size (width, height), read from the loaded array
    """
    LOGGER.info(f"Loading image from {image_path}")
    # Only the pixels and the focal length are needed from the loader. The
    # original size is read from the array itself, so the logged and returned
    # size is always (width, height).
    # NOTE(review): confirm what io.load_rgb returns in its second slot.
    image_np, _, f_px = io.load_rgb(image_path)
    original_height, original_width = image_np.shape[:2]
    LOGGER.info(
        f"Original image size: {(original_width, original_height)}, "
        f"focal length: {f_px:.2f}px"
    )

    # io.load_rgb returns uint8; convert to float32 explicitly.
    # The C-ordered copy keeps torch.from_numpy from rejecting negative-stride
    # arrays or warning about read-only buffers.
    image_tensor = torch.from_numpy(image_np.copy()).float() / 255.0
    image_tensor = image_tensor.permute(2, 0, 1)  # HWC -> CHW

    target_height, target_width = target_size
    if (original_width, original_height) != (target_width, target_height):
        LOGGER.info(f"Resizing to {target_width}x{target_height}")
        import torch.nn.functional as F

        image_tensor = F.interpolate(
            image_tensor.unsqueeze(0),
            size=(target_height, target_width),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)

    # permute() yields a strided view. Make the batch tensor contiguous in both
    # paths so that .numpy() hands a C-contiguous array to Core ML.
    image_tensor = image_tensor.unsqueeze(0).contiguous()  # (1, C, H, W)
    LOGGER.info(
        f"Preprocessed image shape: {image_tensor.shape}, "
        f"range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]"
    )
    return image_tensor, f_px, (original_width, original_height)
def validate_with_image(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
) -> bool:
    """Validate Core ML model outputs against PyTorch model using a real input image.

    Both models receive the same preprocessed image and the same disparity
    factor (focal length in pixels divided by the original image width).

    - Non-quaternion outputs pass when their max absolute difference is within
      a per-output tolerance.
    - Quaternions are normalized and sign-canonicalized, then compared by
      angular difference (mean / p99 / max, in degrees).

    A markdown results table is logged.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape (height, width).

    Returns:
        True if every output passes, False otherwise.
    """

    def _find_coreml_key(name: str, index: int) -> str:
        # Exact name first, then a match on the first word of the name,
        # finally positional fallback.
        if name in coreml_outputs:
            return name
        base_name = name.split("_")[0]
        for key in coreml_outputs:
            if base_name in key.lower():
                return key
        return list(coreml_outputs.keys())[index]

    def _normalize_quaternion(q: np.ndarray) -> np.ndarray:
        norm = np.linalg.norm(q, axis=-1, keepdims=True)
        return q / np.clip(norm, 1e-12, None)

    def _canonicalize_quaternion(q: np.ndarray) -> np.ndarray:
        # q and -q encode the same rotation. Flip each quaternion so that its
        # largest-magnitude component is non-negative.
        max_component_idx = np.argmax(np.abs(q), axis=-1, keepdims=True)
        max_component = np.take_along_axis(q, max_component_idx, axis=-1)
        return np.where(max_component < 0, -q, q)

    LOGGER.info("=" * 60)
    LOGGER.info("Validating Core ML model against PyTorch with real image")
    LOGGER.info("=" * 60)

    # load_and_preprocess_image returns (tensor, f_px, (width, height)).
    # Unpack it so that only the tensor is fed to the models.
    test_image, f_px, (orig_width, _orig_height) = load_and_preprocess_image(
        image_path, input_shape
    )
    # Same disparity definition as validate_with_single_image_detailed.
    test_disparity = np.array([f_px / orig_width], dtype=np.float32)
    LOGGER.info(
        f"Using disparity_factor = {test_disparity[0]:.6f} "
        f"(f_px={f_px:.2f} / width={orig_width})"
    )

    # Run PyTorch model
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()
    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))
    LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")

    # Run Core ML model on the identical inputs
    coreml_outputs = mlmodel.predict(
        {
            "image": test_image.numpy(),
            "disparity_factor": test_disparity,
        }
    )
    LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")

    output_names = list(OUTPUT_NAMES)
    # Per-output max-abs-diff tolerances for real-image validation.
    # The quaternion entry is unused; quaternions are judged by angular_tolerances.
    tolerances = {
        "mean_vectors_3d_positions": 1.2,
        "singular_values_scales": 0.01,
        "quaternions_rotations": 5.0,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
    }
    # Angular tolerances for quaternions (in degrees)
    angular_tolerances = {
        "mean": 0.1,
        "p99": 1.0,
        "max": 15.0,
    }

    # Log input image statistics
    LOGGER.info("\n=== Input Image Statistics ===")
    LOGGER.info(f"Image path: {image_path}")
    LOGGER.info(f"Image shape: {test_image.shape}")
    LOGGER.info(f"Image range: [{test_image.min():.4f}, {test_image.max():.4f}]")
    LOGGER.info(f"Image mean: {test_image.mean(dim=[1,2,3]).tolist()}")
    LOGGER.info("=" * 30)

    # Depth/position analysis
    pt_positions = pt_outputs[0].numpy()
    coreml_positions = coreml_outputs[
        [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
    ]
    LOGGER.info("\n=== Depth/Position Statistics ===")
    LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")
    z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
    LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
    LOGGER.info("=================================\n")

    all_passed = True
    validation_results = []
    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()
        coreml_output = coreml_outputs[_find_coreml_key(name, i)]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            pt_canonical = _canonicalize_quaternion(_normalize_quaternion(pt_output))
            coreml_canonical = _canonicalize_quaternion(_normalize_quaternion(coreml_output))
            diff = np.abs(pt_canonical - coreml_canonical)

            # |<q1, q2>| is invariant to the q / -q sign ambiguity.
            # The angle between the rotations is 2 * arccos(|<q1, q2>|).
            dot_products = np.clip(
                np.abs(np.sum(pt_canonical * coreml_canonical, axis=-1)), 0.0, 1.0
            )
            angular_diff_deg = np.degrees(2 * np.arccos(dot_products))
            max_angular = np.max(angular_diff_deg)
            mean_angular = np.mean(angular_diff_deg)
            p99_angular = np.percentile(angular_diff_deg, 99)

            failure_reasons = []
            if mean_angular > angular_tolerances["mean"]:
                failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
            if p99_angular > angular_tolerances["p99"]:
                failure_reasons.append(f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°")
            if max_angular > angular_tolerances["max"]:
                failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")
            quat_passed = not failure_reasons

            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "max_angular": f"{max_angular:.4f}",
                "mean_angular": f"{mean_angular:.4f}",
                "p99_angular": f"{p99_angular:.4f}",
                "passed": quat_passed,
                "failure_reason": "; ".join(failure_reasons),
            })
            if not quat_passed:
                all_passed = False
        else:
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, 0.01)
            max_diff = np.max(diff)
            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}",
            })
            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False
        validation_results.append(result)

    # Output validation results as markdown table
    LOGGER.info("\n### Image Validation Results\n")
    LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
    LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")
    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        if "max_angular" in result:
            angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
        else:
            angular_info = "-"
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {angular_info} | {status} |")
    LOGGER.info("")
    return all_passed
def validate_with_image_set(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_paths: list[Path],
    input_shape: tuple[int, int] = (1536, 1536),
) -> bool:
    """Compare Core ML and PyTorch outputs across several input images.

    Each existing image is checked with ``validate_with_single_image_detailed``.
    All images share one ``QuaternionValidator`` configured with lenient
    angular tolerances, because real images vary more than random noise.
    Missing images are logged as errors and count as failures. Every
    per-output result is tagged with its image name and logged in one
    combined markdown table.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_paths: List of paths to input images for validation.
        input_shape: Expected input image shape (height, width).

    Returns:
        True only if every image exists and every output of every image passed.
    """
    banner = "=" * 60
    LOGGER.info(banner)
    LOGGER.info(f"Validating Core ML model with {len(image_paths)} images")
    LOGGER.info(banner)

    # Degrees; looser than the random-input defaults.
    shared_validator = QuaternionValidator(
        angular_tolerances={"mean": 0.2, "p99": 2.0, "p99_9": 5.0, "max": 25.0}
    )

    combined_results: list[dict] = []
    failure_count = 0
    for path in image_paths:
        if not path.exists():
            LOGGER.error(f"Input image not found: {path}")
            failure_count += 1
            continue

        LOGGER.info(f"\n--- Validating with {path.name} ---")
        per_image = validate_with_single_image_detailed(
            mlmodel, pytorch_model, path, input_shape, shared_validator
        )
        for entry in per_image:
            entry["image"] = path.name
        combined_results.extend(per_image)

        if any(not entry["passed"] for entry in per_image):
            failure_count += 1

    LOGGER.info("\n" + banner)
    LOGGER.info("### Multi-Image Validation Summary")
    LOGGER.info(banner + "\n")

    if combined_results:
        LOGGER.info(
            format_validation_table(combined_results, "", include_image_column=True)
        )
        LOGGER.info("")

    return failure_count == 0
def validate_with_single_image_detailed(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int],
    quat_validator: QuaternionValidator | None = None,
) -> list[dict]:
    """Validate with a single image and return detailed results.

    The disparity factor is ``f_px / original width``, matching predict.py.
    Both models are run through ``run_inference_pair``. Position and scale
    diagnostics are logged, and the outputs are compared with
    ``compare_outputs`` using ``ToleranceConfig().image_tolerances``.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape.
        quat_validator: Optional QuaternionValidator instance. When omitted,
            one is built from ``ToleranceConfig().angular_tolerances_image``.

    Returns:
        List of validation result dictionaries (as produced by compare_outputs).
    """
    # Load and preprocess the input image with focal length
    test_image, f_px, (orig_width, orig_height) = load_and_preprocess_image(image_path, input_shape)

    # Compute disparity_factor as focal_length / width (matching predict.py)
    disparity_factor = f_px / orig_width
    LOGGER.info(f"Using disparity_factor = {disparity_factor:.6f} (f_px={f_px:.2f} / width={orig_width})")

    # Run inference on both models
    pt_outputs, coreml_outputs = run_inference_pair(
        pytorch_model, mlmodel, test_image,
        disparity_factor=disparity_factor,
        log_internals=True
    )

    # Log depth/position statistics for debugging
    pt_positions = pt_outputs[0]
    coreml_key = find_coreml_output_key("mean_vectors_3d_positions", coreml_outputs)
    coreml_positions = coreml_outputs[coreml_key]

    LOGGER.info(f"=== Depth/Position Statistics ({image_path.name}) ===")
    LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}")

    # Per-axis absolute position differences
    pos_diff = np.abs(pt_positions - coreml_positions)
    LOGGER.info(f"Position difference (X,Y,Z) - max: [{pos_diff[..., 0].max():.6f}, {pos_diff[..., 1].max():.6f}, {pos_diff[..., 2].max():.6f}]")
    LOGGER.info(f"Position difference (X,Y,Z) - mean: [{pos_diff[..., 0].mean():.6f}, {pos_diff[..., 1].mean():.6f}, {pos_diff[..., 2].mean():.6f}]")

    # Relative Z error: if it is roughly constant, the mismatch scales with
    # depth (e.g. a global scale issue). The denominator uses |z|. Clipping
    # the signed value would map every negative reference value to 1e-6 and
    # report huge, meaningless percentages.
    z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
    z_ratio = z_diff / np.clip(np.abs(pt_positions[..., 2]), 1e-6, None)
    LOGGER.info(f"Z relative error - mean: {z_ratio.mean()*100:.4f}%, max: {z_ratio.max()*100:.4f}%")

    # Relative scale error, with the same |x| guard
    pt_scales = pt_outputs[1]
    coreml_scales_key = find_coreml_output_key("singular_values_scales", coreml_outputs)
    coreml_scales = coreml_outputs[coreml_scales_key]
    scales_diff = np.abs(pt_scales - coreml_scales)
    scales_ratio = scales_diff / np.clip(np.abs(pt_scales), 1e-6, None)
    LOGGER.info(f"Scales relative error - mean: {scales_ratio.mean()*100:.4f}%, max: {scales_ratio.max()*100:.4f}%")

    # Tolerances for real image validation
    tolerance_config = ToleranceConfig()
    tolerances = tolerance_config.image_tolerances

    # Use provided validator or create default with image tolerances
    if quat_validator is None:
        quat_validator = QuaternionValidator(
            angular_tolerances=tolerance_config.angular_tolerances_image
        )

    # Compare outputs
    validation_results = compare_outputs(
        pt_outputs,
        coreml_outputs,
        tolerances,
        quat_validator,
        image_name=image_path.name
    )
    return validation_results
def validate_with_single_image(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int],
    quat_validator: QuaternionValidator | None = None,
) -> bool:
    """Validate with a single image using the QuaternionValidator.

    Both models receive the same preprocessed image and the same disparity
    factor (``f_px / original width``).

    - Quaternions are judged by ``quat_validator``.
    - Every other output passes when its max absolute difference is within a
      per-output tolerance.

    A markdown table is logged via ``format_validation_table``.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape.
        quat_validator: Optional QuaternionValidator instance. A default
            ``QuaternionValidator()`` is used when omitted.

    Returns:
        True if validation passes, False otherwise.
    """
    # load_and_preprocess_image returns (tensor, f_px, (width, height)).
    # Unpack it so that only the tensor is fed to the models.
    test_image, f_px, (orig_width, _orig_height) = load_and_preprocess_image(
        image_path, input_shape
    )
    # Same disparity definition as validate_with_single_image_detailed.
    test_disparity = np.array([f_px / orig_width], dtype=np.float32)
    LOGGER.info(
        f"Using disparity_factor = {test_disparity[0]:.6f} "
        f"(f_px={f_px:.2f} / width={orig_width})"
    )

    # Run PyTorch model
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()
    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))

    # Run Core ML model on the identical inputs
    coreml_outputs = mlmodel.predict(
        {
            "image": test_image.numpy(),
            "disparity_factor": test_disparity,
        }
    )

    output_names = list(OUTPUT_NAMES)
    # Max-abs-diff tolerances for real image validation.
    # The quaternion entry is unused; quat_validator decides quaternions.
    tolerances = {
        "mean_vectors_3d_positions": 1.2,
        "singular_values_scales": 0.01,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
        "quaternions_rotations": 5.0,
    }

    if quat_validator is None:
        quat_validator = QuaternionValidator()

    LOGGER.info(f"Image: {image_path.name}, shape: {test_image.shape}, range: [{test_image.min():.4f}, {test_image.max():.4f}]")

    all_passed = True
    validation_results = []
    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        # Exact name first, then a match on the first word, finally positional fallback.
        coreml_key = None
        if name in coreml_outputs:
            coreml_key = name
        else:
            base_name = name.split("_")[0]
            for key in coreml_outputs:
                if base_name in key.lower():
                    coreml_key = key
                    break
        if coreml_key is None:
            coreml_key = list(coreml_outputs.keys())[i]
        coreml_output = coreml_outputs[coreml_key]

        result = {"output": name, "passed": True, "failure_reason": ""}
        if name == "quaternions_rotations":
            quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_path.name)
            result.update({
                "max_diff": f"{quat_result['stats']['max']:.6f}",
                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
                "passed": quat_result["passed"],
                "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
            })
            if not quat_result["passed"]:
                all_passed = False
        else:
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, 0.01)
            max_diff = np.max(diff)
            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
            })
            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False
        validation_results.append(result)

    # Output validation results as markdown table
    LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
    table = format_validation_table(validation_results, image_path.name, include_image_column=False)
    LOGGER.info(table)
    LOGGER.info("")
    return all_passed
def main():
    """Main conversion script.

    Parses CLI arguments and loads the SHARP checkpoint, then converts it to
    Core ML. When ``--validate`` is given, it also validates the result:
    against the ``--input-image`` files if any are given, otherwise with
    random inputs using optional ``--tolerance-*`` overrides.

    Returns:
        Process exit code: 0 on success, 1 when requested validation fails.
    """
    parser = argparse.ArgumentParser(
        description="Convert SHARP PyTorch model to Core ML format"
    )
    parser.add_argument(
        "-c", "--checkpoint",
        type=Path,
        default=None,
        help="Path to PyTorch checkpoint. Downloads default if not provided.",
    )
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=Path("sharp.mlpackage"),
        help="Output path for Core ML model (default: sharp.mlpackage)",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1536,
        help="Input image height (default: 1536)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1536,
        help="Input image width (default: 1536)",
    )
    parser.add_argument(
        "--precision",
        choices=["float16", "float32"],
        default="float32",
        help="Compute precision (default: float32)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate Core ML model against PyTorch",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--input-image",
        type=Path,
        default=None,
        action="append",
        help="Path to input image for validation (can be specified multiple times, requires --validate)",
    )
    parser.add_argument(
        "--tolerance-mean",
        type=float,
        default=None,
        help="Custom mean angular tolerance in degrees (default: 0.01 for random, 0.1 for images)",
    )
    parser.add_argument(
        "--tolerance-p99",
        type=float,
        default=None,
        help="Custom P99 angular tolerance in degrees (default: 0.5 for random, 1.0 for images)",
    )
    parser.add_argument(
        "--tolerance-max",
        type=float,
        default=None,
        help="Custom max angular tolerance in degrees (default: 15.0)",
    )
    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Compare against None rather than relying on truthiness, so that an
    # explicit 0.0 tolerance counts as an override.
    tolerance_overridden = any(
        value is not None
        for value in (args.tolerance_mean, args.tolerance_p99, args.tolerance_max)
    )
    if args.input_image and not args.validate:
        LOGGER.warning("--input-image has no effect without --validate")

    # Load PyTorch model
    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    # Setup conversion parameters
    input_shape = (args.height, args.width)
    precision = ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32

    # Convert to Core ML
    LOGGER.info("Converting using direct tracing...")
    mlmodel = convert_to_coreml(
        predictor,
        args.output,
        input_shape=input_shape,
        compute_precision=precision,
    )
    LOGGER.info(f"Core ML model saved to {args.output}")

    # Validate if requested
    if args.validate:
        if args.input_image:
            # Image validation uses its own tolerances; the CLI overrides are not forwarded.
            if tolerance_overridden:
                LOGGER.warning(
                    "--tolerance-* options only apply to random-input validation; "
                    "they are ignored when --input-image is given"
                )
            validation_passed = validate_with_image_set(mlmodel, predictor, args.input_image, input_shape)
        else:
            # Random-input validation. When at least one threshold is
            # overridden, the unspecified ones fall back to the values below.
            # These differ from validate_coreml_model's own defaults, which
            # are used only when no override is given.
            angular_tolerances = None
            if tolerance_overridden:
                angular_tolerances = {
                    "mean": args.tolerance_mean if args.tolerance_mean is not None else 0.01,
                    "p99": args.tolerance_p99 if args.tolerance_p99 is not None else 0.5,
                    "p99_9": 2.0,
                    "max": args.tolerance_max if args.tolerance_max is not None else 15.0,
                }
            validation_passed = validate_coreml_model(mlmodel, predictor, input_shape, angular_tolerances=angular_tolerances)

        if validation_passed:
            LOGGER.info("✓ Validation passed!")
        else:
            LOGGER.error("✗ Validation failed!")
            return 1

    LOGGER.info("Conversion complete!")
    return 0
if __name__ == "__main__":
    # Run the CLI exactly once and propagate main()'s return value as the
    # process exit code. SystemExit does not depend on the site-provided
    # exit() builtin.
    raise SystemExit(main())