Kyle Pearson committed on
Commit cfcc093 · 1 Parent(s): 1df9b82

Fix Core ML precision mismatches, update quaternion handling, add global scale storage, enhance detailed model validation, improve multi-image support, fix metadata UUIDs, enable sRGB color output conversion.

README.md CHANGED
@@ -63,7 +63,7 @@ Use the provided [sharp.swift](sharp.swift) inference script to load the model a
 swiftc -O -o run_sharp sharp.swift -framework CoreML -framework CoreImage -framework AppKit
 
 # Run inference on an image and decimate the output by 50%
-./run_sharp sharp.mlpackage test.png test.ply -d 0.5
+./run_sharp sharp.mlpackage city.png city.ply -d 0.5
 ```
 
 > Inference on an Apple M4 Max takes ~1.9 seconds.
convert.py CHANGED
@@ -84,20 +84,23 @@ class SharpModelTraceable(nn.Module):
         monodepth_output = self.monodepth_model(image)
         monodepth_disparity = monodepth_output.disparity
 
-        # Convert disparity to depth with higher precision
-        # Use tighter clamp bounds and higher precision intermediate computation
+        # Convert disparity to depth - use float32 to match Core ML execution
+        # Core ML uses float32 precision, so using double() here creates a mismatch
         disparity_factor_expanded = disparity_factor[:, None, None, None]
-
-        # Cast to float64 for more precise division, then back to float32
-        disparity_clamped = monodepth_disparity.clamp(min=1e-6, max=1e4)
-        monodepth = disparity_factor_expanded.double() / disparity_clamped.double()
-        monodepth = monodepth.float()
+
+        # Clamp disparity to prevent numerical instability (matches model exactly)
+        disparity_clamped = monodepth_disparity.clamp(min=1e-4, max=1e4)
+        monodepth = disparity_factor_expanded / disparity_clamped
 
         # Apply depth alignment (inference mode)
         monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)
 
         # Initialize gaussians
         init_output = self.init_model(image, monodepth)
+
+        # Store global_scale for debugging if in eval mode (not during tracing)
+        if hasattr(self, '_store_global_scale'):
+            self._stored_global_scale = init_output.global_scale
 
         # Extract features
         image_features = self.feature_model(
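Why the float64 round trip had to go: a Core ML program has no float64 tensor type, so the extra precision exists only in the PyTorch reference and guarantees reference-vs-converted diffs. A quick way to confirm a trace stays float32-clean — the module and names below are illustrative, not from convert.py:

```python
import torch

# Toy stand-in for the disparity-to-depth step, written float32-only
# the way the commit does.
class DisparityToDepth(torch.nn.Module):
    def forward(self, disparity, factor):
        disparity = disparity.clamp(min=1e-4, max=1e4)
        return factor[:, None, None, None] / disparity

traced = torch.jit.trace(
    DisparityToDepth(),
    (torch.rand(1, 1, 8, 8) + 0.1, torch.tensor([1.0])),
)

# Traced graphs annotate tensor types (Float(...), Double(...)); any Double
# annotation means a float64 intermediate that Core ML would silently downcast.
assert "Double" not in str(traced.graph)
```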
@@ -116,17 +119,26 @@ class SharpModelTraceable(nn.Module):
         )
 
         # Normalize quaternions for consistent validation and inference
-        # This is critical for CoreML conversion accuracy
+        #
+        # IMPORTANT: The SHARP model does NOT canonicalize quaternions during inference.
+        # Quaternions are normalized to unit length but retain their sign ambiguity (q ≡ -q).
+        #
+        # We canonicalize here for two reasons:
+        # 1. Numerical validation: Ensures PyTorch and Core ML outputs can be compared directly
+        # 2. Consistency: Provides deterministic outputs for the same rotation
+        #
+        # This canonicalization is NOT required for rendering, as both q and -q represent
+        # the same 3D rotation. Renderers typically normalize quaternions internally.
         quaternions = gaussians.quaternions
 
-        # Use double precision for quaternion normalization to reduce numerical errors
-        quaternions_fp64 = quaternions.double()
-        quat_norm_sq = torch.sum(quaternions_fp64 * quaternions_fp64, dim=-1, keepdim=True)
-        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-16))
-        quaternions_normalized = quaternions_fp64 / quat_norm
+        # Normalize quaternions to unit length
+        # Use float32 to match Core ML precision
+        quat_norm_sq = torch.sum(quaternions * quaternions, dim=-1, keepdim=True)
+        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-12))
+        quaternions_normalized = quaternions / quat_norm
 
         # Apply sign canonicalization for consistent representation
-        # Find the component with the largest absolute value
+        # Ensure the component with largest absolute value is positive
         abs_quat = torch.abs(quaternions_normalized)
         max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)
 
@@ -646,6 +658,50 @@ class QuaternionValidator:
         }
 
 
+def format_validation_table(
+    validation_results: list[dict],
+    image_name: str,
+    include_image_column: bool = False,
+) -> str:
+    """Format validation results as a markdown table.
+
+    Args:
+        validation_results: List of validation result dicts with keys:
+            output, max_diff, mean_diff, p99_diff, passed, etc.
+        image_name: Name of the image being validated.
+        include_image_column: Whether to include the image name as a column.
+
+    Returns:
+        Formatted markdown table as a string.
+    """
+    lines = []
+
+    if include_image_column:
+        lines.append("| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |")
+        lines.append("|-------|--------|----------|-----------|----------|--------|")
+
+        for result in validation_results:
+            output_name = result["output"].replace("_", " ").title()
+            status = "✅ PASS" if result["passed"] else "❌ FAIL"
+            lines.append(
+                f"| {image_name} | {output_name} | {result['max_diff']} | "
+                f"{result['mean_diff']} | {result['p99_diff']} | {status} |"
+            )
+    else:
+        lines.append("| Output | Max Diff | Mean Diff | P99 Diff | Status |")
+        lines.append("|--------|----------|-----------|----------|--------|")
+
+        for result in validation_results:
+            output_name = result["output"].replace("_", " ").title()
+            status = "✅ PASS" if result["passed"] else "❌ FAIL"
+            lines.append(
+                f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | "
+                f"{result['p99_diff']} | {status} |"
+            )
+
+    return "\n".join(lines)
+
+
 def validate_coreml_model(
     mlmodel: ct.models.MLModel,
     pytorch_model: RGBGaussianPredictor,
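For reference, a hypothetical call to the new `format_validation_table` helper — the numbers are made up, purely to show the table shape it produces:

```python
results = [
    {"output": "colors_rgb_linear", "max_diff": "0.004210",
     "mean_diff": "0.000310", "p99_diff": "0.001750", "passed": True},
    {"output": "opacities_alpha_channel", "max_diff": "0.061000",
     "mean_diff": "0.002100", "p99_diff": "0.014000", "passed": False},
]
print(format_validation_table(results, "city.png", include_image_column=False))
# | Output | Max Diff | Mean Diff | P99 Diff | Status |
# |--------|----------|-----------|----------|--------|
# | Colors Rgb Linear | 0.004210 | 0.000310 | 0.001750 | ✅ PASS |
# | Opacities Alpha Channel | 0.061000 | 0.002100 | 0.014000 | ❌ FAIL |
```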
@@ -1107,7 +1163,7 @@ def validate_with_image_set(
     quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances)
 
     all_passed = True
-    image_results = []
+    all_validation_results = []
 
     for image_path in image_paths:
         if not image_path.exists():
@@ -1117,29 +1173,139 @@ def validate_with_image_set(
 
         LOGGER.info(f"\n--- Validating with {image_path.name} ---")
 
-        # Run validation for this image
-        passed = validate_with_single_image(
+        # Run validation for this image and collect detailed results
+        image_results = validate_with_single_image_detailed(
             mlmodel, pytorch_model, image_path, input_shape, quat_validator
         )
-        image_results.append({"image": image_path.name, "passed": passed})
-
-        if not passed:
+
+        # Add image name to each result
+        for result in image_results:
+            result["image"] = image_path.name
+            all_validation_results.append(result)
+
+        # Check if any results failed
+        if not all(r["passed"] for r in image_results):
             all_passed = False
 
-    # Output summary table
+    # Output combined summary table with all images and outputs
     LOGGER.info("\n" + "=" * 60)
     LOGGER.info("### Multi-Image Validation Summary")
-    LOGGER.info("=" * 60)
-    LOGGER.info(f"| Image | Status |")
-    LOGGER.info("|-------|--------|")
-
-    for result in image_results:
-        status = "✅ PASS" if result["passed"] else "❌ FAIL"
-        LOGGER.info(f"| {result['image']} | {status} |")
-
-    LOGGER.info("")
+    LOGGER.info("=" * 60 + "\n")
+
+    # Generate combined table
+    if all_validation_results:
+        table = format_validation_table(all_validation_results, "", include_image_column=True)
+        LOGGER.info(table)
+        LOGGER.info("")
 
     return all_passed
+
+
+def validate_with_single_image_detailed(
+    mlmodel: ct.models.MLModel,
+    pytorch_model: RGBGaussianPredictor,
+    image_path: Path,
+    input_shape: tuple[int, int],
+    quat_validator: QuaternionValidator | None = None,
+) -> list[dict]:
+    """Validate with a single image and return detailed results.
+
+    Args:
+        mlmodel: The Core ML model to validate.
+        pytorch_model: The original PyTorch model.
+        image_path: Path to the input image file.
+        input_shape: Expected input image shape.
+        quat_validator: Optional QuaternionValidator instance.
+
+    Returns:
+        List of validation result dictionaries.
+    """
+    # Load and preprocess the input image
+    test_image = load_and_preprocess_image(image_path, input_shape)
+    test_disparity = np.array([1.0], dtype=np.float32)
+
+    # Run PyTorch model
+    traceable_wrapper = SharpModelTraceable(pytorch_model)
+    traceable_wrapper.eval()
+
+    with torch.no_grad():
+        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))
+
+    # Run Core ML model
+    test_image_np = test_image.numpy()
+    coreml_inputs = {
+        "image": test_image_np,
+        "disparity_factor": test_disparity,
+    }
+    coreml_outputs = mlmodel.predict(coreml_inputs)
+
+    # Output configuration
+    output_names = ["mean_vectors_3d_positions", "singular_values_scales", "quaternions_rotations", "colors_rgb_linear", "opacities_alpha_channel"]
+
+    # Tolerances for real image validation
+    tolerances = {
+        "mean_vectors_3d_positions": 1.2,
+        "singular_values_scales": 0.01,
+        "colors_rgb_linear": 0.01,
+        "opacities_alpha_channel": 0.05,
+        "quaternions_rotations": 5.0,
+    }
+
+    # Use provided validator or create default
+    if quat_validator is None:
+        quat_validator = QuaternionValidator()
+
+    # Collect validation results
+    validation_results = []
+
+    for i, name in enumerate(output_names):
+        pt_output = pt_outputs[i].numpy()
+
+        # Find matching Core ML output
+        coreml_key = None
+        if name in coreml_outputs:
+            coreml_key = name
+        else:
+            for key in coreml_outputs:
+                base_name = name.split('_')[0]
+                if base_name in key.lower():
+                    coreml_key = key
+                    break
+            if coreml_key is None:
+                coreml_key = list(coreml_outputs.keys())[i]
+
+        coreml_output = coreml_outputs[coreml_key]
+        result = {"output": name, "passed": True, "failure_reason": ""}
+
+        if name == "quaternions_rotations":
+            # Use QuaternionValidator
+            quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_path.name)
+
+            result.update({
+                "max_diff": f"{quat_result['stats']['max']:.6f}",
+                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
+                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
+                "passed": quat_result["passed"],
+                "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
+            })
+        else:
+            diff = np.abs(pt_output - coreml_output)
+            output_tolerance = tolerances.get(name, 0.01)
+            max_diff = np.max(diff)
+
+            result.update({
+                "max_diff": f"{max_diff:.6f}",
+                "mean_diff": f"{np.mean(diff):.6f}",
+                "p99_diff": f"{np.percentile(diff, 99):.6f}",
+            })
+
+            if max_diff > output_tolerance:
+                result["passed"] = False
+                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
+
+        validation_results.append(result)
+
+    return validation_results
 
 
 def validate_with_single_image(
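`QuaternionValidator`'s implementation is not part of this diff, but the degree-valued stats it reports (mean, p99, max) are consistent with the standard sign-invariant angular distance between unit quaternions. A sketch of that metric — the function name and exact formulation are our assumption:

```python
import numpy as np

def quaternion_angle_deg(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    # Sign-invariant: q and -q represent the same rotation, so take |dot|.
    dot = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    # theta = 2 * arccos(|<q1, q2>|), reported in degrees.
    return np.degrees(2.0 * np.arccos(dot))

identity = np.array([1.0, 0.0, 0.0, 0.0])
assert quaternion_angle_deg(identity, -identity) == 0.0
```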
@@ -1189,6 +1355,7 @@ def validate_with_single_image(
         "singular_values_scales": 0.01,
         "colors_rgb_linear": 0.01,
         "opacities_alpha_channel": 0.05,
+        "quaternions_rotations": 5.0,
     }
 
     # Use provided validator or create default
@@ -1200,6 +1367,7 @@ def validate_with_single_image(
 
     # Collect validation results
     all_passed = True
+    validation_results = []
 
     for i, name in enumerate(output_names):
         pt_output = pt_outputs[i].numpy()
@@ -1218,32 +1386,46 @@ def validate_with_single_image(
             coreml_key = list(coreml_outputs.keys())[i]
 
         coreml_output = coreml_outputs[coreml_key]
+        result = {"output": name, "passed": True, "failure_reason": ""}
 
         if name == "quaternions_rotations":
             # Use QuaternionValidator
             quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_path.name)
 
-            LOGGER.info(f"Quaternions: mean={quat_result['stats']['mean']:.4f}°, p99={quat_result['stats']['p99']:.4f}°, max={quat_result['stats']['max']:.4f}°")
-
-            # Output outlier analysis
-            if quat_result["outliers"]:
-                for threshold, data in quat_result["outliers"].items():
-                    LOGGER.info(f"  {threshold}: {data['count']} ({data['percentage']:.4f}%)")
+            result.update({
+                "max_diff": f"{quat_result['stats']['max']:.6f}",
+                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
+                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
+                "passed": quat_result["passed"],
+                "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
+            })
 
             if not quat_result["passed"]:
-                LOGGER.warning(f"  ⚠️ Quaternion validation failed: {'; '.join(quat_result['failure_reasons'])}")
                 all_passed = False
         else:
             diff = np.abs(pt_output - coreml_output)
             output_tolerance = tolerances.get(name, 0.01)
             max_diff = np.max(diff)
 
-            LOGGER.info(f"{name}: max_diff={max_diff:.6f}, mean_diff={np.mean(diff):.6f}")
+            result.update({
+                "max_diff": f"{max_diff:.6f}",
+                "mean_diff": f"{np.mean(diff):.6f}",
+                "p99_diff": f"{np.percentile(diff, 99):.6f}",
+            })
 
             if max_diff > output_tolerance:
-                LOGGER.warning(f"  ⚠️ {name} failed: max_diff {max_diff:.6f} > tolerance {output_tolerance:.6f}")
+                result["passed"] = False
+                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                 all_passed = False
 
+        validation_results.append(result)
+
+    # Output validation results as markdown table
+    LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
+    table = format_validation_table(validation_results, image_path.name, include_image_column=False)
+    LOGGER.info(table)
+    LOGGER.info("")
+
     return all_passed
 
 
sharp.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c36e0aa4ffde76052412f2c399cd140781e614ba732c33e9b72b9f8d7d1fe002
-size 938777
+oid sha256:ca2a548947bdf1616a9c7ddf093c27dc0aeb8225a1e50cb40eb098d7aa47a2b5
+size 938769
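The `.mlmodel` blob is stored via Git LFS, so the diff only shows the pointer file. Per the LFS spec, the pointer's `oid` is the SHA-256 of the real blob once it has been pulled (`git lfs pull`), which a sketch like this can verify:

```python
import hashlib
import pathlib

# After `git lfs pull`, the checked-out file is the actual weights blob;
# its SHA-256 must equal the oid recorded in the pointer above.
blob = pathlib.Path("sharp.mlpackage/Data/com.apple.CoreML/model.mlmodel").read_bytes()
print(hashlib.sha256(blob).hexdigest())
# expected: ca2a548947bdf1616a9c7ddf093c27dc0aeb8225a1e50cb40eb098d7aa47a2b5
```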
sharp.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
 {
   "fileFormatVersion": "1.0.0",
   "itemInfoEntries": {
-    "8EBB39F7-795C-4451-A2EE-090F6695386A": {
+    "1504890B-E584-4EC2-A1CF-F87AE1A1BAA0": {
       "author": "com.apple.CoreML",
       "description": "CoreML Model Weights",
       "name": "weights",
       "path": "com.apple.CoreML/weights"
     },
-    "97AA1BE5-373D-4A1B-B3DF-74F91F8B0AFE": {
+    "D59C5780-FA59-423A-8088-BCF64225C1B3": {
       "author": "com.apple.CoreML",
       "description": "CoreML Model Specification",
       "name": "model.mlmodel",
       "path": "com.apple.CoreML/model.mlmodel"
     }
   },
-  "rootModelIdentifier": "97AA1BE5-373D-4A1B-B3DF-74F91F8B0AFE"
+  "rootModelIdentifier": "D59C5780-FA59-423A-8088-BCF64225C1B3"
 }
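The UUID fix only matters if `rootModelIdentifier` stays in sync with the entry that points at the model specification. A quick consistency check over the manifest — the path matches this repo, but the check itself is our addition:

```python
import json
import pathlib

manifest = json.loads(pathlib.Path("sharp.mlpackage/Manifest.json").read_text())
root = manifest["rootModelIdentifier"]

# The root identifier must key an itemInfoEntries record that points at the
# model specification -- the mismatch this commit repairs.
entry = manifest["itemInfoEntries"][root]
assert entry["path"] == "com.apple.CoreML/model.mlmodel", entry
```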
sharp.swift CHANGED
@@ -486,6 +486,8 @@ class SHARPModelRunner {
         }
 
         // Colors: Convert linearRGB -> sRGB -> spherical harmonics
+        // Model outputs linearRGB colors for proper alpha blending
+        // We convert to sRGB for compatibility with public renderers
         let colorR = colorPtr[i * 3 + 0]
         let colorG = colorPtr[i * 3 + 1]
         let colorB = colorPtr[i * 3 + 2]
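The Swift constants for this conversion live in sharp.swift and are not shown in the hunk; for reference, the standard sRGB encoding (IEC 61966-2-1) that a linear-to-sRGB step implements looks like this in NumPy:

```python
import numpy as np

def linear_to_srgb(c: np.ndarray) -> np.ndarray:
    """Standard sRGB transfer function applied to linear values in [0, 1]."""
    c = np.clip(c, 0.0, 1.0)
    # Linear segment near black, gamma 1/2.4 with offset elsewhere.
    return np.where(c <= 0.0031308,
                    12.92 * c,
                    1.055 * np.power(c, 1.0 / 2.4) - 0.055)
```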