Kyle Pearson committed on
Commit
1df9b82
·
1 Parent(s): 1aecf6d

Add quaternion validator with configurable tolerances, enhance coreml validation with detailed stats, add PIL-based image preprocessing, introduce batch/image-specific validation functions, update CLI flags for input/output handling.

Browse files
convert.py CHANGED
@@ -15,6 +15,7 @@ import coremltools as ct
15
  import numpy as np
16
  import torch
17
  import torch.nn as nn
 
18
 
19
  # Import SHARP model components
20
  from sharp.models import PredictorParams, create_predictor
@@ -481,11 +482,176 @@ def convert_to_coreml_with_preprocessing(
481
  return mlmodel
482
 
483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  def validate_coreml_model(
485
  mlmodel: ct.models.MLModel,
486
  pytorch_model: RGBGaussianPredictor,
487
  input_shape: tuple[int, int] = (1536, 1536),
488
  tolerance: float = 0.01,
 
489
  ) -> bool:
490
  """Validate Core ML model outputs against PyTorch model.
491
 
@@ -494,6 +660,7 @@ def validate_coreml_model(
494
  pytorch_model: The original PyTorch model.
495
  input_shape: Input image shape (height, width).
496
  tolerance: Maximum allowed difference between outputs.
 
497
 
498
  Returns:
499
  True if validation passes, False otherwise.
@@ -527,14 +694,13 @@ def validate_coreml_model(
527
  }
528
  coreml_outputs = mlmodel.predict(coreml_inputs)
529
 
530
- # Debug: Print shapes and keys
531
  LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
532
  LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")
533
 
534
- # Compare outputs with per-output tolerances
535
  output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
536
 
537
- # Define tighter tolerances per output type
538
  tolerances = {
539
  "mean_vectors_3d_positions": 0.001,
540
  "singular_values_scales": 0.0001,
@@ -543,12 +709,17 @@ def validate_coreml_model(
543
  "opacities_alpha_channel": 0.005,
544
  }
545
 
546
- # Angular tolerances for quaternions (in degrees)
547
- angular_tolerances = {
548
- "mean": 0.01,
549
- "p99": 0.5,
550
- "max": 10.0,
551
- }
 
 
 
 
 
552
 
553
  all_passed = True
554
 
@@ -565,7 +736,7 @@ def validate_coreml_model(
565
  LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
566
  LOGGER.info("=================================")
567
 
568
- # Collect validation results for table output
569
  validation_results = []
570
 
571
  for i, name in enumerate(output_names):
@@ -588,14 +759,233 @@ def validate_coreml_model(
588
  coreml_output = coreml_outputs[coreml_key]
589
  result = {"output": name, "passed": True, "failure_reason": ""}
590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
  # Special handling for quaternions
592
  if name == "quaternions_rotations":
593
  pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
594
  pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)
595
-
596
  coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
597
  coreml_output_normalized = coreml_output / np.clip(coreml_quat_norm, 1e-12, None)
598
-
599
  def canonicalize_quaternion(q):
600
  abs_q = np.abs(q)
601
  max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
@@ -603,22 +993,30 @@ def validate_coreml_model(
603
  np.put_along_axis(selector, max_component_idx, 1, axis=-1)
604
  max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
605
  return np.where(max_component_sign < 0, -q, q)
606
-
607
  pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
608
  coreml_output_canonical = canonicalize_quaternion(coreml_output_normalized)
609
-
610
  diff = np.abs(pt_output_canonical - coreml_output_canonical)
611
  dot_products = np.sum(pt_output_canonical * coreml_output_canonical, axis=-1)
612
- dot_products = np.clip(np.abs(dot_products), 0.0, 1.0)
 
 
 
 
 
 
 
 
613
  angular_diff_rad = 2 * np.arccos(dot_products)
614
  angular_diff_deg = np.degrees(angular_diff_rad)
615
  max_angular = np.max(angular_diff_deg)
616
  mean_angular = np.mean(angular_diff_deg)
617
  p99_angular = np.percentile(angular_diff_deg, 99)
618
-
619
  quat_passed = True
620
  failure_reasons = []
621
-
622
  if mean_angular > angular_tolerances["mean"]:
623
  quat_passed = False
624
  failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
@@ -628,7 +1026,7 @@ def validate_coreml_model(
628
  if max_angular > angular_tolerances["max"]:
629
  quat_passed = False
630
  failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")
631
-
632
  result.update({
633
  "max_diff": f"{np.max(diff):.6f}",
634
  "mean_diff": f"{np.mean(diff):.6f}",
@@ -643,7 +1041,7 @@ def validate_coreml_model(
643
  all_passed = False
644
  else:
645
  diff = np.abs(pt_output - coreml_output)
646
- output_tolerance = tolerances.get(name, tolerance)
647
  result.update({
648
  "max_diff": f"{np.max(diff):.6f}",
649
  "mean_diff": f"{np.mean(diff):.6f}",
@@ -654,24 +1052,197 @@ def validate_coreml_model(
654
  result["passed"] = False
655
  result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
656
  all_passed = False
657
-
658
  validation_results.append(result)
659
-
660
  # Output validation results as markdown table
661
- if validation_results:
662
- LOGGER.info("\n### Validation Results\n")
663
- LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
664
- LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")
665
-
666
- for result in validation_results:
667
- output_name = result["output"].replace("_", " ").title()
668
- if "max_angular" in result:
669
- angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
670
- else:
671
- angular_info = "-"
672
- status = "✅ PASS" if result["passed"] else f"❌ FAIL"
673
- LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {angular_info} | {status} |")
674
- LOGGER.info("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
 
676
  return all_passed
677
 
@@ -726,6 +1297,31 @@ def main():
726
  action="store_true",
727
  help="Enable verbose logging",
728
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
729
 
730
  args = parser.parse_args()
731
 
@@ -764,7 +1360,21 @@ def main():
764
 
765
  # Validate if requested
766
  if args.validate:
767
- validation_passed = validate_coreml_model(mlmodel, predictor, input_shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
 
769
  if validation_passed:
770
  LOGGER.info("✓ Validation passed!")
 
15
  import numpy as np
16
  import torch
17
  import torch.nn as nn
18
+ from PIL import Image
19
 
20
  # Import SHARP model components
21
  from sharp.models import PredictorParams, create_predictor
 
482
  return mlmodel
483
 
484
 
485
+ class QuaternionValidator:
486
+ """Validator for quaternion comparisons with configurable tolerances and outlier analysis."""
487
+
488
+ DEFAULT_ANGULAR_TOLERANCES = {
489
+ "mean": 0.01,
490
+ "p99": 0.5,
491
+ "p99_9": 2.0,
492
+ "max": 15.0,
493
+ }
494
+
495
+ def __init__(
496
+ self,
497
+ angular_tolerances: dict[str, float] | None = None,
498
+ enable_outlier_analysis: bool = True,
499
+ outlier_thresholds: list[float] | None = None,
500
+ ):
501
+ """Initialize validator with tolerances.
502
+
503
+ Args:
504
+ angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees.
505
+ enable_outlier_analysis: Whether to perform detailed outlier analysis.
506
+ outlier_thresholds: List of angle thresholds for outlier counting.
507
+ """
508
+ self.angular_tolerances = angular_tolerances or self.DEFAULT_ANGULAR_TOLERANCES.copy()
509
+ self.enable_outlier_analysis = enable_outlier_analysis
510
+ self.outlier_thresholds = outlier_thresholds or [5.0, 10.0, 15.0]
511
+
512
+ @staticmethod
513
+ def canonicalize_quaternion(q: np.ndarray) -> np.ndarray:
514
+ """Canonicalize quaternion to ensure consistent representation.
515
+
516
+ Ensures the quaternion with the largest absolute component is positive.
517
+ This handles the sign ambiguity where q and -q represent the same rotation.
518
+
519
+ Args:
520
+ q: Quaternion array of shape (..., 4)
521
+
522
+ Returns:
523
+ Canonicalized quaternion array.
524
+ """
525
+ abs_q = np.abs(q)
526
+ max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
527
+ selector = np.zeros_like(q)
528
+ np.put_along_axis(selector, max_component_idx, 1.0, axis=-1)
529
+ max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
530
+ return np.where(max_component_sign < 0, -q, q)
531
+
532
+ @staticmethod
533
+ def compute_angular_differences(
534
+ quats1: np.ndarray, quats2: np.ndarray
535
+ ) -> tuple[np.ndarray, dict[str, float]]:
536
+ """Compute angular differences between two sets of quaternions.
537
+
538
+ Args:
539
+ quats1: First set of quaternions shape (N, 4)
540
+ quats2: Second set of quaternions shape (N, 4)
541
+
542
+ Returns:
543
+ Tuple of (angular_differences in degrees, statistics dict)
544
+ """
545
+ # Normalize quaternions
546
+ norm1 = np.linalg.norm(quats1, axis=-1, keepdims=True)
547
+ norm2 = np.linalg.norm(quats2, axis=-1, keepdims=True)
548
+ quats1_norm = quats1 / np.clip(norm1, 1e-12, None)
549
+ quats2_norm = quats2 / np.clip(norm2, 1e-12, None)
550
+
551
+ # Canonicalize both
552
+ quats1_canon = QuaternionValidator.canonicalize_quaternion(quats1_norm)
553
+ quats2_canon = QuaternionValidator.canonicalize_quaternion(quats2_norm)
554
+
555
+ # Compute dot products for both q·q and q·(-q) to handle sign ambiguity
556
+ dot_products = np.sum(quats1_canon * quats2_canon, axis=-1)
557
+ dot_products_flipped = np.sum(quats1_canon * (-quats2_canon), axis=-1)
558
+
559
+ # Take the maximum absolute dot product (handle sign ambiguity)
560
+ dot_products = np.maximum(np.abs(dot_products), np.abs(dot_products_flipped))
561
+ dot_products = np.clip(dot_products, 0.0, 1.0)
562
+
563
+ # Compute angular differences
564
+ angular_diff_rad = 2.0 * np.arccos(dot_products)
565
+ angular_diff_deg = np.degrees(angular_diff_rad)
566
+
567
+ # Compute statistics
568
+ stats = {
569
+ "mean": float(np.mean(angular_diff_deg)),
570
+ "std": float(np.std(angular_diff_deg)),
571
+ "min": float(np.min(angular_diff_deg)),
572
+ "max": float(np.max(angular_diff_deg)),
573
+ "p50": float(np.percentile(angular_diff_deg, 50)),
574
+ "p90": float(np.percentile(angular_diff_deg, 90)),
575
+ "p99": float(np.percentile(angular_diff_deg, 99)),
576
+ "p99_9": float(np.percentile(angular_diff_deg, 99.9)),
577
+ }
578
+
579
+ return angular_diff_deg, stats
580
+
581
+ def analyze_outliers(
582
+ self, angular_diff_deg: np.ndarray
583
+ ) -> dict[str, dict[str, int | float]]:
584
+ """Analyze outliers in angular differences.
585
+
586
+ Args:
587
+ angular_diff_deg: Array of angular differences in degrees.
588
+
589
+ Returns:
590
+ Dict with outlier statistics for each threshold.
591
+ """
592
+ if not self.enable_outlier_analysis:
593
+ return {}
594
+
595
+ outlier_stats = {}
596
+ total = len(angular_diff_deg)
597
+
598
+ for threshold in self.outlier_thresholds:
599
+ count = int(np.sum(angular_diff_deg > threshold))
600
+ outlier_stats[f">{threshold}°"] = {
601
+ "count": count,
602
+ "percentage": (count / total) * 100.0 if total > 0 else 0.0,
603
+ }
604
+
605
+ return outlier_stats
606
+
607
+ def validate(
608
+ self,
609
+ pt_quaternions: np.ndarray,
610
+ coreml_quaternions: np.ndarray,
611
+ image_name: str = "Unknown",
612
+ ) -> dict:
613
+ """Validate Core ML quaternions against PyTorch quaternions.
614
+
615
+ Args:
616
+ pt_quaternions: PyTorch quaternion outputs.
617
+ coreml_quaternions: Core ML quaternion outputs.
618
+ image_name: Name of the image being validated.
619
+
620
+ Returns:
621
+ Dict with validation results including status, stats, and outliers.
622
+ """
623
+ angular_diff_deg, stats = self.compute_angular_differences(
624
+ pt_quaternions, coreml_quaternions
625
+ )
626
+ outlier_stats = self.analyze_outliers(angular_diff_deg)
627
+
628
+ # Check tolerances
629
+ passed = True
630
+ failure_reasons = []
631
+
632
+ for key, tolerance in self.angular_tolerances.items():
633
+ if key in stats and stats[key] > tolerance:
634
+ passed = False
635
+ failure_reasons.append(
636
+ f"{key} angular {stats[key]:.4f}° > tolerance {tolerance:.4f}°"
637
+ )
638
+
639
+ return {
640
+ "image": image_name,
641
+ "passed": passed,
642
+ "failure_reasons": failure_reasons,
643
+ "stats": stats,
644
+ "outliers": outlier_stats,
645
+ "num_gaussians": len(angular_diff_deg),
646
+ }
647
+
648
+
649
  def validate_coreml_model(
650
  mlmodel: ct.models.MLModel,
651
  pytorch_model: RGBGaussianPredictor,
652
  input_shape: tuple[int, int] = (1536, 1536),
653
  tolerance: float = 0.01,
654
+ angular_tolerances: dict[str, float] | None = None,
655
  ) -> bool:
656
  """Validate Core ML model outputs against PyTorch model.
657
 
 
660
  pytorch_model: The original PyTorch model.
661
  input_shape: Input image shape (height, width).
662
  tolerance: Maximum allowed difference between outputs.
663
+ angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees.
664
 
665
  Returns:
666
  True if validation passes, False otherwise.
 
694
  }
695
  coreml_outputs = mlmodel.predict(coreml_inputs)
696
 
 
697
  LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
698
  LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")
699
 
700
+ # Output configuration
701
  output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
702
 
703
+ # Define tolerances per output type
704
  tolerances = {
705
  "mean_vectors_3d_positions": 0.001,
706
  "singular_values_scales": 0.0001,
 
709
  "opacities_alpha_channel": 0.005,
710
  }
711
 
712
+ # Use provided angular tolerances or defaults
713
+ if angular_tolerances is None:
714
+ angular_tolerances = {
715
+ "mean": 0.01,
716
+ "p99": 0.1,
717
+ "p99_9": 1.0,
718
+ "max": 5.0,
719
+ }
720
+
721
+ # Initialize quaternion validator
722
+ quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances)
723
 
724
  all_passed = True
725
 
 
736
  LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
737
  LOGGER.info("=================================")
738
 
739
+ # Collect validation results
740
  validation_results = []
741
 
742
  for i, name in enumerate(output_names):
 
759
  coreml_output = coreml_outputs[coreml_key]
760
  result = {"output": name, "passed": True, "failure_reason": ""}
761
 
762
+ # Special handling for quaternions
763
+ if name == "quaternions_rotations":
764
+ # Use the new QuaternionValidator
765
+ quat_result = quat_validator.validate(pt_output, coreml_output, image_name="Random")
766
+
767
+ result.update({
768
+ "max_diff": f"{quat_result['stats']['max']:.6f}",
769
+ "mean_diff": f"{quat_result['stats']['mean']:.6f}",
770
+ "p99_diff": f"{quat_result['stats']['p99']:.6f}",
771
+ "p99_9_diff": f"{quat_result['stats']['p99_9']:.6f}",
772
+ "max_angular": f"{quat_result['stats']['max']:.4f}",
773
+ "mean_angular": f"{quat_result['stats']['mean']:.4f}",
774
+ "p99_angular": f"{quat_result['stats']['p99']:.4f}",
775
+ "passed": quat_result["passed"],
776
+ "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
777
+ "quat_stats": quat_result["stats"],
778
+ "outliers": quat_result["outliers"],
779
+ })
780
+ if not quat_result["passed"]:
781
+ all_passed = False
782
+ else:
783
+ diff = np.abs(pt_output - coreml_output)
784
+ output_tolerance = tolerances.get(name, tolerance)
785
+ result.update({
786
+ "max_diff": f"{np.max(diff):.6f}",
787
+ "mean_diff": f"{np.mean(diff):.6f}",
788
+ "p99_diff": f"{np.percentile(diff, 99):.6f}",
789
+ "tolerance": f"{output_tolerance:.6f}"
790
+ })
791
+ if np.max(diff) > output_tolerance:
792
+ result["passed"] = False
793
+ result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
794
+ all_passed = False
795
+
796
+ validation_results.append(result)
797
+
798
+ # Output validation results as markdown table
799
+ LOGGER.info("\n### Validation Results\n")
800
+ LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | P99.9 Diff | Angular Diff (°) | Status |")
801
+ LOGGER.info("|--------|----------|-----------|----------|------------|------------------|--------|")
802
+
803
+ for result in validation_results:
804
+ output_name = result["output"].replace("_", " ").title()
805
+ if "max_angular" in result:
806
+ angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
807
+ p99_9 = result.get("p99_9_diff", "-")
808
+ status = "✅ PASS" if result["passed"] else f"❌ FAIL"
809
+ LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {p99_9} | {angular_info} | {status} |")
810
+ else:
811
+ status = "✅ PASS" if result["passed"] else f"❌ FAIL"
812
+ LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | - | - | {status} |")
813
+ LOGGER.info("")
814
+
815
+ # Output quaternion outlier analysis if available
816
+ for result in validation_results:
817
+ if "outliers" in result and result["outliers"]:
818
+ LOGGER.info("### Quaternion Outlier Analysis\n")
819
+ LOGGER.info(f"| Threshold | Count | Percentage |")
820
+ LOGGER.info("|-----------|-------|------------|")
821
+ for threshold, data in result["outliers"].items():
822
+ LOGGER.info(f"| {threshold} | {data['count']} | {data['percentage']:.4f}% |")
823
+ LOGGER.info("")
824
+
825
+ return all_passed
826
+
827
+
828
def load_and_preprocess_image(
    image_path: Path,
    target_size: tuple[int, int] = (1536, 1536),
) -> torch.Tensor:
    """Load and preprocess an input image for SHARP inference.

    The image is converted to RGB, resized with bilinear resampling when its
    dimensions differ from ``target_size``, scaled from 8-bit values to
    [0, 1] and rearranged to channel-first layout with a batch dimension.

    Args:
        image_path: Path to the input image file.
        target_size: Target (height, width) for resizing.

    Returns:
        Preprocessed float32 image tensor of shape (1, 3, H, W) in range [0, 1].
    """
    LOGGER.info(f"Loading image from {image_path}")

    target_height, target_width = target_size

    # Open in a context manager so the underlying file handle is released.
    # convert() returns a fully loaded copy (also for images that are already
    # RGB), so the result stays usable after the file is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    original_size = image.size  # PIL reports (width, height)
    LOGGER.info(f"Original image size: {original_size}")

    # PIL sizes are (width, height) while target_size is (height, width);
    # compare in PIL order so non-square targets are handled correctly.
    pil_target_size = (target_width, target_height)
    if image.size != pil_target_size:
        LOGGER.info(f"Resizing to {target_width}x{target_height}")
        image = image.resize(pil_target_size, Image.BILINEAR)

    # np.array() on a PIL RGB image yields an (H, W, C) array; scale to [0, 1].
    image_np = np.array(image, dtype=np.float32) / 255.0

    # Move channels first and add the batch dimension.
    image_np = image_np.transpose(2, 0, 1)  # (3, H, W)
    image_tensor = torch.from_numpy(image_np).unsqueeze(0)  # (1, 3, H, W)

    LOGGER.info(f"Preprocessed image shape: {image_tensor.shape}, range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]")

    return image_tensor
869
+
870
+
871
+ def validate_with_image(
872
+ mlmodel: ct.models.MLModel,
873
+ pytorch_model: RGBGaussianPredictor,
874
+ image_path: Path,
875
+ input_shape: tuple[int, int] = (1536, 1536),
876
+ ) -> bool:
877
+ """Validate Core ML model outputs against PyTorch model using a real input image.
878
+
879
+ Args:
880
+ mlmodel: The Core ML model to validate.
881
+ pytorch_model: The original PyTorch model.
882
+ image_path: Path to the input image file.
883
+ input_shape: Expected input image shape (height, width).
884
+
885
+ Returns:
886
+ True if validation passes, False otherwise.
887
+ """
888
+ LOGGER.info("=" * 60)
889
+ LOGGER.info("Validating Core ML model against PyTorch with real image")
890
+ LOGGER.info("=" * 60)
891
+
892
+ # Load and preprocess the input image
893
+ test_image = load_and_preprocess_image(image_path, input_shape)
894
+ test_disparity = np.array([1.0], dtype=np.float32)
895
+
896
+ # Run PyTorch model
897
+ traceable_wrapper = SharpModelTraceable(pytorch_model)
898
+ traceable_wrapper.eval()
899
+
900
+ with torch.no_grad():
901
+ pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))
902
+
903
+ LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
904
+
905
+ # Run Core ML model
906
+ test_image_np = test_image.numpy()
907
+ coreml_inputs = {
908
+ "image": test_image_np,
909
+ "disparity_factor": test_disparity,
910
+ }
911
+ coreml_outputs = mlmodel.predict(coreml_inputs)
912
+
913
+ LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")
914
+
915
+ # Output configuration
916
+ output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
917
+
918
+ # Define tolerances per output type for real image validation
919
+ # Using p99-based tolerances to handle outliers better
920
+ tolerances = {
921
+ "mean_vectors_3d_positions": 1.2,
922
+ "singular_values_scales": 0.01,
923
+ "quaternions_rotations": 5.0,
924
+ "colors_rgb_linear": 0.01,
925
+ "opacities_alpha_channel": 0.05,
926
+ }
927
+
928
+ # Angular tolerances for quaternions (in degrees)
929
+ angular_tolerances = {
930
+ "mean": 0.1,
931
+ "p99": 1.0,
932
+ "max": 15.0,
933
+ }
934
+
935
+ all_passed = True
936
+
937
+ # Log input image statistics
938
+ LOGGER.info(f"\n=== Input Image Statistics ===")
939
+ LOGGER.info(f"Image path: {image_path}")
940
+ LOGGER.info(f"Image shape: {test_image.shape}")
941
+ LOGGER.info(f"Image range: [{test_image.min():.4f}, {test_image.max():.4f}]")
942
+ LOGGER.info(f"Image mean: {test_image.mean(dim=[1,2,3]).tolist()}")
943
+ LOGGER.info("=" * 30)
944
+
945
+ # Depth/position analysis
946
+ pt_positions = pt_outputs[0].numpy()
947
+ coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
948
+ coreml_positions = coreml_outputs[coreml_key]
949
+
950
+ LOGGER.info("\n=== Depth/Position Statistics ===")
951
+ LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
952
+ LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")
953
+
954
+ z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
955
+ LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
956
+ LOGGER.info("=================================\n")
957
+
958
+ # Collect validation results
959
+ validation_results = []
960
+
961
+ for i, name in enumerate(output_names):
962
+ pt_output = pt_outputs[i].numpy()
963
+
964
+ # Find matching Core ML output
965
+ coreml_key = None
966
+ if name in coreml_outputs:
967
+ coreml_key = name
968
+ else:
969
+ # Try partial match
970
+ for key in coreml_outputs:
971
+ base_name = name.split('_')[0]
972
+ if base_name in key.lower():
973
+ coreml_key = key
974
+ break
975
+ if coreml_key is None:
976
+ coreml_key = list(coreml_outputs.keys())[i]
977
+
978
+ coreml_output = coreml_outputs[coreml_key]
979
+ result = {"output": name, "passed": True, "failure_reason": ""}
980
+
981
  # Special handling for quaternions
982
  if name == "quaternions_rotations":
983
  pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
984
  pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)
985
+
986
  coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
987
  coreml_output_normalized = coreml_output / np.clip(coreml_quat_norm, 1e-12, None)
988
+
989
  def canonicalize_quaternion(q):
990
  abs_q = np.abs(q)
991
  max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
 
993
  np.put_along_axis(selector, max_component_idx, 1, axis=-1)
994
  max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
995
  return np.where(max_component_sign < 0, -q, q)
996
+
997
  pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
998
  coreml_output_canonical = canonicalize_quaternion(coreml_output_normalized)
999
+
1000
  diff = np.abs(pt_output_canonical - coreml_output_canonical)
1001
  dot_products = np.sum(pt_output_canonical * coreml_output_canonical, axis=-1)
1002
+ dot_products_flipped = np.sum(pt_output_canonical * (-coreml_output_canonical), axis=-1)
1003
+ # Take the absolute value and ensure we compare q with -q if needed
1004
+ # This handles the sign ambiguity: q and -q represent the same rotation
1005
+ dot_products = np.where(
1006
+ np.abs(dot_products) > np.abs(dot_products_flipped),
1007
+ np.abs(dot_products),
1008
+ np.abs(dot_products_flipped)
1009
+ )
1010
+ dot_products = np.clip(dot_products, 0.0, 1.0)
1011
  angular_diff_rad = 2 * np.arccos(dot_products)
1012
  angular_diff_deg = np.degrees(angular_diff_rad)
1013
  max_angular = np.max(angular_diff_deg)
1014
  mean_angular = np.mean(angular_diff_deg)
1015
  p99_angular = np.percentile(angular_diff_deg, 99)
1016
+
1017
  quat_passed = True
1018
  failure_reasons = []
1019
+
1020
  if mean_angular > angular_tolerances["mean"]:
1021
  quat_passed = False
1022
  failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
 
1026
  if max_angular > angular_tolerances["max"]:
1027
  quat_passed = False
1028
  failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")
1029
+
1030
  result.update({
1031
  "max_diff": f"{np.max(diff):.6f}",
1032
  "mean_diff": f"{np.mean(diff):.6f}",
 
1041
  all_passed = False
1042
  else:
1043
  diff = np.abs(pt_output - coreml_output)
1044
+ output_tolerance = tolerances.get(name, 0.01)
1045
  result.update({
1046
  "max_diff": f"{np.max(diff):.6f}",
1047
  "mean_diff": f"{np.mean(diff):.6f}",
 
1052
  result["passed"] = False
1053
  result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
1054
  all_passed = False
1055
+
1056
  validation_results.append(result)
1057
+
1058
  # Output validation results as markdown table
1059
+ LOGGER.info("\n### Image Validation Results\n")
1060
+ LOGGER.info(f"| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
1061
+ LOGGER.info(f"|--------|----------|-----------|----------|------------------|--------|")
1062
+
1063
+ for result in validation_results:
1064
+ output_name = result["output"].replace("_", " ").title()
1065
+ if "max_angular" in result:
1066
+ angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
1067
+ else:
1068
+ angular_info = "-"
1069
+ status = "✅ PASS" if result["passed"] else f"❌ FAIL"
1070
+ LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {angular_info} | {status} |")
1071
+ LOGGER.info("")
1072
+
1073
+ return all_passed
1074
+
1075
+
1076
def validate_with_image_set(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_paths: list[Path],
    input_shape: tuple[int, int] = (1536, 1536),
) -> bool:
    """Validate Core ML model against PyTorch using multiple input images.

    Each existing image is validated with ``validate_with_single_image`` using
    one shared ``QuaternionValidator``. A per-image PASS/FAIL summary table is
    logged at the end; images that could not be found are listed as failed.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_paths: List of paths to input images for validation.
        input_shape: Expected input image shape (height, width).

    Returns:
        True if all validations pass, False otherwise (including when any
        image path does not exist).
    """
    LOGGER.info("=" * 60)
    LOGGER.info(f"Validating Core ML model with {len(image_paths)} images")
    LOGGER.info("=" * 60)

    if not image_paths:
        # Nothing to check: the result below is vacuously True, so make that visible.
        LOGGER.warning("No input images were provided; nothing was validated")

    # Angular tolerances for image validation (more lenient than random validation)
    # Real images have more variation than random noise
    angular_tolerances = {
        "mean": 0.2,
        "p99": 2.0,
        "p99_9": 5.0,
        "max": 25.0,
    }

    # One validator instance is shared across all images.
    quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances)

    all_passed = True
    image_results = []

    for image_path in image_paths:
        if not image_path.exists():
            LOGGER.error(f"Input image not found: {image_path}")
            # Record the miss so the summary table agrees with the overall result.
            image_results.append({"image": image_path.name, "passed": False})
            all_passed = False
            continue

        LOGGER.info(f"\n--- Validating with {image_path.name} ---")

        passed = validate_with_single_image(
            mlmodel, pytorch_model, image_path, input_shape, quat_validator
        )
        image_results.append({"image": image_path.name, "passed": passed})

        if not passed:
            all_passed = False

    # Output summary table
    LOGGER.info("\n" + "=" * 60)
    LOGGER.info("### Multi-Image Validation Summary")
    LOGGER.info("=" * 60)
    LOGGER.info("| Image | Status |")
    LOGGER.info("|-------|--------|")

    for result in image_results:
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        LOGGER.info(f"| {result['image']} | {status} |")

    LOGGER.info("")

    return all_passed
1143
+
1144
+
1145
def validate_with_single_image(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int],
    quat_validator: QuaternionValidator | None = None,
) -> bool:
    """Validate Core ML outputs against PyTorch outputs for one real image.

    The image is loaded with ``load_and_preprocess_image`` and fed, together
    with a fixed disparity factor of 1.0, to both the traceable PyTorch
    wrapper and the Core ML model. Quaternion outputs are judged by the
    ``QuaternionValidator``; every other output is judged by its maximum
    absolute element-wise difference against a per-output tolerance.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape.
        quat_validator: Optional QuaternionValidator instance. A default
            ``QuaternionValidator()`` is created when omitted.

    Returns:
        True if every output passes validation, False otherwise. A missing
        Core ML output, an element-count mismatch, or a non-finite (NaN/inf)
        difference counts as a failure rather than raising or passing.
    """
    # Load and preprocess the input image
    test_image = load_and_preprocess_image(image_path, input_shape)
    test_disparity = np.array([1.0], dtype=np.float32)

    # Run PyTorch model
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))

    # Run Core ML model on the very same inputs
    coreml_inputs = {
        "image": test_image.numpy(),
        "disparity_factor": test_disparity,
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)
    coreml_keys = list(coreml_outputs.keys())

    # Output configuration (order matches the tuple returned by the wrapper)
    output_names = [
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel",
    ]

    # Tolerances for real image validation
    tolerances = {
        "mean_vectors_3d_positions": 1.2,
        "singular_values_scales": 0.01,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
    }

    # Use provided validator or create default
    if quat_validator is None:
        quat_validator = QuaternionValidator()

    # Log input image statistics
    LOGGER.info(f"Image: {image_path.name}, shape: {test_image.shape}, range: [{test_image.min():.4f}, {test_image.max():.4f}]")

    all_passed = True

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        # Resolve the Core ML key: exact name first, then the first key that
        # contains the leading token of the expected name, then position.
        coreml_key = None
        if name in coreml_outputs:
            coreml_key = name
        else:
            base_name = name.split('_')[0]
            coreml_key = next(
                (key for key in coreml_keys if base_name in key.lower()),
                None,
            )
            if coreml_key is None and i < len(coreml_keys):
                coreml_key = coreml_keys[i]

        if coreml_key is None:
            # Fewer Core ML outputs than expected: report instead of crashing.
            LOGGER.warning(f"  ⚠️ {name} failed: no matching Core ML output (available: {coreml_keys})")
            all_passed = False
            continue

        coreml_output = coreml_outputs[coreml_key]

        if name == "quaternions_rotations":
            # Use QuaternionValidator
            quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_path.name)

            LOGGER.info(f"Quaternions: mean={quat_result['stats']['mean']:.4f}°, p99={quat_result['stats']['p99']:.4f}°, max={quat_result['stats']['max']:.4f}°")

            # Output outlier analysis
            if quat_result["outliers"]:
                for threshold, data in quat_result["outliers"].items():
                    LOGGER.info(f"  {threshold}: {data['count']} ({data['percentage']:.4f}%)")

            if not quat_result["passed"]:
                LOGGER.warning(f"  ⚠️ Quaternion validation failed: {'; '.join(quat_result['failure_reasons'])}")
                all_passed = False
            continue

        coreml_array = np.asarray(coreml_output)
        if coreml_array.shape != pt_output.shape:
            if coreml_array.size != pt_output.size:
                # Different element counts cannot be compared element-wise.
                LOGGER.warning(f"  ⚠️ {name} failed: shape mismatch, PyTorch {pt_output.shape} vs Core ML {coreml_array.shape}")
                all_passed = False
                continue
            # Same element count (e.g. a dropped batch axis): align explicitly
            # rather than relying on implicit broadcasting.
            coreml_array = coreml_array.reshape(pt_output.shape)

        diff = np.abs(pt_output - coreml_array)
        output_tolerance = tolerances.get(name, 0.01)
        max_diff = np.max(diff)

        LOGGER.info(f"{name}: max_diff={max_diff:.6f}, mean_diff={np.mean(diff):.6f}")

        # A NaN max_diff compares False against any tolerance, so it has to be
        # rejected explicitly or broken outputs would pass validation.
        if not np.isfinite(max_diff) or max_diff > output_tolerance:
            LOGGER.warning(f"  ⚠️ {name} failed: max_diff {max_diff:.6f} > tolerance {output_tolerance:.6f}")
            all_passed = False

    return all_passed
1248
 
 
1297
  action="store_true",
1298
  help="Enable verbose logging",
1299
  )
1300
+ parser.add_argument(
1301
+ "--input-image",
1302
+ type=Path,
1303
+ default=None,
1304
+ action="append",
1305
+ help="Path to input image for validation (can be specified multiple times, requires --validate)",
1306
+ )
1307
+ parser.add_argument(
1308
+ "--tolerance-mean",
1309
+ type=float,
1310
+ default=None,
1311
+ help="Custom mean angular tolerance in degrees (default: 0.01 for random, 0.1 for images)",
1312
+ )
1313
+ parser.add_argument(
1314
+ "--tolerance-p99",
1315
+ type=float,
1316
+ default=None,
1317
+ help="Custom P99 angular tolerance in degrees (default: 0.5 for random, 1.0 for images)",
1318
+ )
1319
+ parser.add_argument(
1320
+ "--tolerance-max",
1321
+ type=float,
1322
+ default=None,
1323
+ help="Custom max angular tolerance in degrees (default: 15.0)",
1324
+ )
1325
 
1326
  args = parser.parse_args()
1327
 
 
1360
 
1361
  # Validate if requested
1362
  if args.validate:
1363
+ if args.input_image:
1364
+ # Validate with one or more real input images
1365
+ validation_passed = validate_with_image_set(mlmodel, predictor, args.input_image, input_shape)
1366
+ else:
1367
+ # Validate with random input (default behavior)
1368
+ # Build custom angular tolerances from CLI args
1369
+ angular_tolerances = None
1370
+ if args.tolerance_mean or args.tolerance_p99 or args.tolerance_max:
1371
+ angular_tolerances = {
1372
+ "mean": args.tolerance_mean if args.tolerance_mean else 0.01,
1373
+ "p99": args.tolerance_p99 if args.tolerance_p99 else 0.5,
1374
+ "p99_9": 2.0,
1375
+ "max": args.tolerance_max if args.tolerance_max else 15.0,
1376
+ }
1377
+ validation_passed = validate_coreml_model(mlmodel, predictor, input_shape, angular_tolerances=angular_tolerances)
1378
 
1379
  if validation_passed:
1380
  LOGGER.info("✓ Validation passed!")
sharp.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9bb25d6180e305984d1faeda322c433079999895ea78fda5ea9fe02d63d92bd3
3
  size 938777
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c36e0aa4ffde76052412f2c399cd140781e614ba732c33e9b72b9f8d7d1fe002
3
  size 938777
sharp.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "B664B39D-DCBF-4A11-A7D6-74633C44EFCF": {
5
- "author": "com.apple.CoreML",
6
- "description": "CoreML Model Specification",
7
- "name": "model.mlmodel",
8
- "path": "com.apple.CoreML/model.mlmodel"
9
- },
10
- "D768A76E-EC7C-4ED0-91C5-FF591F7D5359": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Weights",
13
  "name": "weights",
14
  "path": "com.apple.CoreML/weights"
 
 
 
 
 
 
15
  }
16
  },
17
- "rootModelIdentifier": "B664B39D-DCBF-4A11-A7D6-74633C44EFCF"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "8EBB39F7-795C-4451-A2EE-090F6695386A": {
 
 
 
 
 
 
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Weights",
7
  "name": "weights",
8
  "path": "com.apple.CoreML/weights"
9
+ },
10
+ "97AA1BE5-373D-4A1B-B3DF-74F91F8B0AFE": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Specification",
13
+ "name": "model.mlmodel",
14
+ "path": "com.apple.CoreML/model.mlmodel"
15
  }
16
  },
17
+ "rootModelIdentifier": "97AA1BE5-373D-4A1B-B3DF-74F91F8B0AFE"
18
  }
test.ply CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:20c37e93b212cb2fee9cfbebf4b1abffb15baacf38e7983364f07a228be7ab14
3
  size 33030941
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b08f5a8cc6f1afffae48c257f0bf51b5f66dc0a13ff02aca16fc8ffe0a9d7f4f
3
  size 33030941