"""Convert SHARP PyTorch model to ONNX format."""
from __future__ import annotations
import argparse
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
import torch.nn.functional as F
from sharp.models import PredictorParams, create_predictor
from sharp.models.predictor import RGBGaussianPredictor
from sharp.utils import io
LOGGER = logging.getLogger(__name__)
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
OUTPUT_NAMES = [
"mean_vectors_3d_positions",
"singular_values_scales",
"quaternions_rotations",
"colors_rgb_linear",
"opacities_alpha_channel",
]
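# Expected output shapes, with N the number of Gaussians (channel counts taken
# from the splitting logic in run_inference_pair below):
#   (1, N, 3) positions, (1, N, 3) scales, (1, N, 4) quaternions,
#   (1, N, 3) colors, (1, N) opacities.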
@dataclass
class ToleranceConfig:
    random_tolerances: dict | None = None
    image_tolerances: dict | None = None
    angular_tolerances_random: dict | None = None
    angular_tolerances_image: dict | None = None
def __post_init__(self):
if self.random_tolerances is None:
self.random_tolerances = {
"mean_vectors_3d_positions": 0.001,
"singular_values_scales": 0.0001,
"quaternions_rotations": 10.0, # Increased for ONNX numerical precision
"colors_rgb_linear": 0.002,
"opacities_alpha_channel": 0.005,
}
if self.image_tolerances is None:
self.image_tolerances = {
"mean_vectors_3d_positions": 3.5,
"singular_values_scales": 0.035,
"quaternions_rotations": 10.0, # Increased for ONNX numerical precision
"colors_rgb_linear": 0.01,
"opacities_alpha_channel": 0.05,
}
if self.angular_tolerances_random is None:
self.angular_tolerances_random = {"mean": 0.01, "p99": 0.1, "p99_9": 1.0, "max": 10.0} # Increased for ONNX precision
if self.angular_tolerances_image is None:
self.angular_tolerances_image = {"mean": 0.2, "p99": 2.0, "p99_9": 5.0, "max": 25.0}
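# Fields left at None pick up the defaults above; overriding a field replaces
# that whole dict while the other fields keep their defaults, e.g.:
#   cfg = ToleranceConfig(image_tolerances={"colors_rgb_linear": 0.02})
#   cfg.random_tolerances["quaternions_rotations"]  # -> 10.0 (default)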
class QuaternionValidator:
    """Compares quaternion sets by rotation angle, robust to the q / -q sign ambiguity."""
def __init__(self, angular_tolerances=None, enable_outlier_analysis=True, outlier_thresholds=None):
self.angular_tolerances = angular_tolerances or {"mean": 0.01, "p99": 0.5, "p99_9": 2.0, "max": 15.0}
self.enable_outlier_analysis = enable_outlier_analysis
self.outlier_thresholds = outlier_thresholds or [5.0, 10.0, 15.0]
    @staticmethod
    def canonicalize_quaternion(q):
        """Flip each quaternion's sign so its largest-magnitude component is non-negative."""
abs_q = np.abs(q)
max_idx = np.argmax(abs_q, axis=-1, keepdims=True)
selector = np.zeros_like(q)
np.put_along_axis(selector, max_idx, 1.0, axis=-1)
max_sign = np.sum(q * selector, axis=-1, keepdims=True)
return np.where(max_sign < 0, -q, q)
@staticmethod
def compute_angular_differences(quats1, quats2):
n1 = np.linalg.norm(quats1, axis=-1, keepdims=True)
n2 = np.linalg.norm(quats2, axis=-1, keepdims=True)
q1 = quats1 / np.clip(n1, 1e-12, None)
q2 = quats2 / np.clip(n2, 1e-12, None)
q1 = QuaternionValidator.canonicalize_quaternion(q1)
q2 = QuaternionValidator.canonicalize_quaternion(q2)
dots = np.sum(q1 * q2, axis=-1)
dots_flipped = np.sum(q1 * (-q2), axis=-1)
dots = np.maximum(np.abs(dots), np.abs(dots_flipped))
dots = np.clip(dots, 0.0, 1.0)
ang_rad = 2.0 * np.arccos(dots)
ang_deg = np.degrees(ang_rad)
return ang_deg, {
"mean": float(np.mean(ang_deg)),
"std": float(np.std(ang_deg)),
"max": float(np.max(ang_deg)),
"p99": float(np.percentile(ang_deg, 99)),
"p99_9": float(np.percentile(ang_deg, 99.9)),
}
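    # Sign-invariance sketch: q and -q encode the same rotation, so after
    # canonicalization their angular difference is exactly zero:
    #   q = np.array([[0.0, 0.0, 0.0, 1.0]])
    #   _, stats = QuaternionValidator.compute_angular_differences(q, -q)
    #   stats["max"]  # -> 0.0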
def validate(self, pt_quats, onnx_quats, image_name="Unknown"):
diff, stats = self.compute_angular_differences(pt_quats, onnx_quats)
passed = True
reasons = []
for k, t in self.angular_tolerances.items():
if k in stats and stats[k] > t:
passed = False
reasons.append(f"{k} angular {stats[k]:.4f} > {t:.4f}")
return {"image": image_name, "passed": passed, "failure_reasons": reasons, "stats": stats}
class SharpModelTraceable(nn.Module):
    """Flattened SHARP forward pass with tensor-only inputs/outputs, suitable for torch.onnx.export."""
def __init__(self, predictor):
super().__init__()
self.init_model = predictor.init_model
self.feature_model = predictor.feature_model
self.monodepth_model = predictor.monodepth_model
self.prediction_head = predictor.prediction_head
self.gaussian_composer = predictor.gaussian_composer
self.depth_alignment = predictor.depth_alignment
def forward(self, image, disparity_factor):
monodepth_out = self.monodepth_model(image)
disp = monodepth_out.disparity
disp_factor = disparity_factor[:, None, None, None]
disp_clamped = disp.clamp(min=1e-4, max=1e4)
depth = disp_factor / disp_clamped
depth, _ = self.depth_alignment(depth, None, monodepth_out.decoder_features)
init_out = self.init_model(image, depth)
feats = self.feature_model(init_out.feature_input, encodings=monodepth_out.output_features)
deltas = self.prediction_head(feats)
gaussians = self.gaussian_composer(deltas, init_out.gaussian_base_values, init_out.global_scale)
        # Normalize, then canonicalize the sign (same convention as
        # QuaternionValidator.canonicalize_quaternion) so the export is deterministic
        quats = gaussians.quaternions
qnorm = torch.sqrt(torch.clamp(torch.sum(quats * quats, dim=-1, keepdim=True), min=1e-12))
quats = quats / qnorm
abs_q = torch.abs(quats)
max_idx = torch.argmax(abs_q, dim=-1, keepdim=True)
one_hot = torch.zeros_like(quats)
one_hot.scatter_(-1, max_idx, 1.0)
max_sign = torch.sum(quats * one_hot, dim=-1, keepdim=True)
quats = torch.where(max_sign < 0, -quats, quats).float()
return (gaussians.mean_vectors, gaussians.singular_values, quats, gaussians.colors, gaussians.opacities)
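# The wrapper above is what actually gets traced. A dry run mirrors the
# pre-warming loop in convert_to_onnx() (1536x1536 input assumed):
#   model = SharpModelTraceable(predictor)
#   model.eval()
#   with torch.no_grad():
#       outs = model(torch.randn(1, 3, 1536, 1536), torch.tensor([1.0]))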
def cleanup_onnx_files(onnx_path):
"""Clean up ONNX model files including external data files."""
try:
if onnx_path.exists():
onnx_path.unlink()
LOGGER.info(f"Removed {onnx_path}")
except Exception as e:
LOGGER.warning(f"Could not remove {onnx_path}: {e}")
# Also clean up external data file with .onnx.data suffix
data_path = onnx_path.with_suffix('.onnx.data')
try:
if data_path.exists():
data_path.unlink()
LOGGER.info(f"Removed {data_path}")
except Exception as e:
LOGGER.warning(f"Could not remove {data_path}: {e}")
    # Remove stray intermediate files the exporter may leave in the working directory
    cleanup_extraneous_files()
def cleanup_extraneous_files():
    """Remove temporary files that the export step can leave in the working directory."""
    import glob
    import os
    patterns = ["onnx__*", "monodepth_*", "feature_model*", "_Constant_*", "_init_model_*"]
for p in patterns:
for f in glob.glob(p):
try:
os.remove(f)
except Exception:
pass
def load_sharp_model(checkpoint_path=None):
if checkpoint_path is None:
LOGGER.info(f"Downloading model from {DEFAULT_MODEL_URL}")
state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
else:
LOGGER.info(f"Loading checkpoint from {checkpoint_path}")
state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
predictor = create_predictor(PredictorParams())
predictor.load_state_dict(state_dict)
predictor.eval()
return predictor
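# Minimal usage sketch (the no-argument form needs network access to fetch
# DEFAULT_MODEL_URL):
#   predictor = load_sharp_model()                  # download default weights
#   predictor = load_sharp_model(Path("sharp.pt"))  # or use a local checkpoint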
def convert_to_fp16(onnx_path):
"""Convert an ONNX model to FP16 precision.
This function loads an ONNX model, converts all float32 initializers to float16,
and also updates the input/output types to float16 for proper execution.
The result is a smaller model with faster inference on FP16-capable hardware.
"""
LOGGER.info(f"Converting {onnx_path} to FP16...")
# Load the model
model = onnx.load(str(onnx_path))
# Convert all float tensors (initializers/weights) to float16
for tensor in model.graph.initializer:
if tensor.data_type == onnx.TensorProto.FLOAT:
float16_tensor = onnx.numpy_helper.to_array(tensor).astype(np.float16)
tensor.CopyFrom(onnx.numpy_helper.from_array(float16_tensor, tensor.name))
# Convert input types to float16 (if they are float32)
for inp in model.graph.input:
# Skip if this is an initializer (has the same name in initializer list)
if any(init.name == inp.name for init in model.graph.initializer):
continue
if inp.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
inp.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
# Convert output types to float16 (if they are float32)
for out in model.graph.output:
if out.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
out.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
    # Bump the default-domain opset version to at least 13, where FP16 coverage is broader
    for opset in model.opset_import:
        if opset.domain == "" and opset.version < 13:
            opset.version = 13
    # Declare the Microsoft contrib-ops domain (com.microsoft) if absent, so
    # runtimes that register those ops accept the model
    has_ms_domain = any(opset.domain == "com.microsoft" for opset in model.opset_import)
    if not has_ms_domain:
        opset = model.opset_import.add()
        opset.domain = "com.microsoft"
        opset.version = 1
# Save the FP16 model
onnx.save(model, str(onnx_path))
size_mb = Path(onnx_path).stat().st_size / (1024 * 1024)
LOGGER.info(f"FP16 model saved: {onnx_path} ({size_mb:.2f} MB)")
return onnx_path
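# After conversion the graph inputs are float16, so any direct ONNX Runtime feed
# must match. A minimal sketch (shape and provider assumed):
#   session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
#   image = np.random.rand(1, 3, 1536, 1536).astype(np.float16)
#   disparity = np.array([1.0], dtype=np.float16)
#   outputs = session.run(None, {"image": image, "disparity_factor": disparity})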
def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_external_data=None, fp16=False):
LOGGER.info("Exporting to ONNX format...")
    # Drop the scale-map estimator so it is not traced into the exported graph
    predictor.depth_alignment.scale_map_estimator = None
model = SharpModelTraceable(predictor)
model.eval()
LOGGER.info("Pre-warming model...")
with torch.no_grad():
for _ in range(3):
_ = model(torch.randn(1, 3, input_shape[0], input_shape[1]), torch.tensor([1.0]))
cleanup_onnx_files(output_path)
h, w = input_shape
torch.manual_seed(42)
example_image = torch.randn(1, 3, h, w)
example_disparity = torch.tensor([1.0])
LOGGER.info(f"Exporting to ONNX: {output_path}")
    # Axis 0 is the batch and axis 1 the number of Gaussians for every output.
    # opacities is 2D (batch, num_gaussians); the rest are 3D
    # (batch, num_gaussians, channels) with 3/3/4/3 channels respectively.
    dynamic_axes = {name: {0: 'batch', 1: 'num_gaussians'} for name in OUTPUT_NAMES}
torch.onnx.export(
model, (example_image, example_disparity), str(output_path),
export_params=True, verbose=False,
input_names=['image', 'disparity_factor'],
output_names=OUTPUT_NAMES,
dynamic_axes=dynamic_axes,
opset_version=15, # Use opset 15 for better browser compatibility
)
# Handle external data based on use_external_data parameter
try:
model_proto = onnx.load(str(output_path))
model_size_mb = model_proto.ByteSize() / (1024 * 1024)
LOGGER.info(f"Model size: {model_size_mb:.2f} MB")
# Default: use external data for models > 100MB (not typical for browser)
# use_external_data=True: always use external data
# use_external_data=False: never use external data (inline mode for browser)
use_ext = use_external_data if use_external_data is not None else (model_size_mb > 100)
if use_ext:
LOGGER.info("Saving with external data format...")
data_path = output_path.with_suffix('.onnx.data')
onnx.save_model(model_proto, str(output_path), save_as_external_data=True,
all_tensors_to_one_file=True, location=data_path.name)
LOGGER.info(f"External data saved to: {data_path}")
else:
LOGGER.info("Using inline data format (no external .onnx.data file needed)")
except Exception as e:
LOGGER.warning(f"External data format check failed: {e}")
try:
onnx.checker.check_model(str(output_path))
LOGGER.info("ONNX model validation passed")
except Exception as e:
LOGGER.warning(f"ONNX model validation skipped: {e}")
# Apply FP16 quantization if requested
if fp16:
convert_to_fp16(output_path)
cleanup_extraneous_files()
return output_path
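# Illustrative end-to-end call (mirrors what main() does below):
#   predictor = load_sharp_model()
#   convert_to_onnx(predictor, Path("sharp.onnx"), input_shape=(1536, 1536),
#                   use_external_data=False, fp16=False)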
def find_onnx_output_key(name, onnx_outputs):
    """Best-effort match from an expected output name to a session output key."""
if name in onnx_outputs:
return name
for key in onnx_outputs:
if name.split('_')[0] in key.lower():
return key
return list(onnx_outputs.keys())[OUTPUT_NAMES.index(name) if name in OUTPUT_NAMES else 0]
def load_and_preprocess_image(image_path, target_size=(1536, 1536)):
LOGGER.info(f"Loading image from {image_path}")
image_np, orig_size, f_px = io.load_rgb(image_path)
# Fallback to getting size from array if orig_size is None
if orig_size is None:
orig_size = (image_np.shape[1], image_np.shape[0])
LOGGER.info(f"Original size: {orig_size}, focal: {f_px:.2f}px")
tensor = torch.from_numpy(image_np).float() / 255.0
tensor = tensor.permute(2, 0, 1)
if (orig_size[0], orig_size[1]) != (target_size[1], target_size[0]):
LOGGER.info(f"Resizing to {target_size[1]}x{target_size[0]}")
tensor = F.interpolate(tensor.unsqueeze(0), size=target_size, mode="bilinear", align_corners=True).squeeze(0)
tensor = tensor.unsqueeze(0)
LOGGER.info(f"Preprocessed shape: {tensor.shape}, range: [{tensor.min():.4f}, {tensor.max():.4f}]")
return tensor, f_px, orig_size
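# The returned focal length drives the disparity factor used downstream:
# validate_with_image() computes disparity_factor = f_px / w from the original
# width, e.g. a 4032-px-wide photo with f_px = 3024 gives 3024 / 4032 = 0.75.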
def run_inference_pair(pytorch_model, onnx_path, image_tensor, disparity_factor=1.0):
    """Run the traceable PyTorch wrapper and the ONNX session on the same inputs."""
wrapper = SharpModelTraceable(pytorch_model)
wrapper.eval()
image_tensor = image_tensor.float()
disp_pt = torch.tensor([disparity_factor], dtype=torch.float32)
with torch.no_grad():
pt_outputs = wrapper(image_tensor, disp_pt)
pt_np = [o.numpy() for o in pt_outputs]
session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
onnx_inputs = {"image": image_tensor.numpy(), "disparity_factor": np.array([disparity_factor], dtype=np.float32)}
onnx_raw = session.run(None, onnx_inputs)
LOGGER.info(f"ONNX raw outputs count: {len(onnx_raw)}, first shape: {onnx_raw[0].shape if len(onnx_raw) > 0 else 'N/A'}")
# Check if outputs are already separated
if len(onnx_raw) == 5:
# ONNX returns separate outputs
onnx_splits = list(onnx_raw)
elif len(onnx_raw) == 1:
# ONNX returns concatenated output - split it
total_size = onnx_raw[0].shape[-1]
LOGGER.info(f"ONNX single output total size: {total_size}")
        # Per-output channel counts: positions(3) + scales(3) + quats(4) + colors(3) + opacities(1) = 14
sizes = [3, 3, 4, 3, 1]
start = 0
onnx_splits = []
for i, size in enumerate(sizes):
onnx_splits.append(onnx_raw[0][:, :, start:start+size])
start += size
else:
onnx_splits = list(onnx_raw)
return pt_np, onnx_splits
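# The manual loop above is equivalent to one np.split at the cumulative channel
# offsets; a compact sketch for a concatenated (1, N, 14) array:
#   positions, scales, quats, colors, opacities = np.split(raw, [3, 6, 10, 13], axis=-1)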
def format_validation_table(results, image_name="", include_image=False):
lines = []
if include_image:
lines.append("| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |")
lines.append("|-------|--------|----------|-----------|----------|--------|")
for r in results:
name = r["output"].replace("_", " ").title()
status = "PASS" if r["passed"] else "FAIL"
lines.append(f"| {image_name} | {name} | {r['max_diff']} | {r['mean_diff']} | {r['p99_diff']} | {status} |")
else:
lines.append("| Output | Max Diff | Mean Diff | P99 Diff | Status |")
lines.append("|--------|----------|-----------|----------|--------|")
for r in results:
name = r["output"].replace("_", " ").title()
status = "PASS" if r["passed"] else "FAIL"
lines.append(f"| {name} | {r['max_diff']} | {r['mean_diff']} | {r['p99_diff']} | {status} |")
return "\n".join(lines)
def validate_with_image(onnx_path, pytorch_model, image_path, input_shape=(1536, 1536)):
LOGGER.info(f"Validating with image: {image_path}")
test_image, f_px, (w, h) = load_and_preprocess_image(image_path, input_shape)
disparity_factor = f_px / w
LOGGER.info(f"Using disparity_factor = {disparity_factor:.6f}")
pt_outputs, onnx_out = run_inference_pair(pytorch_model, onnx_path, test_image, disparity_factor)
LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
LOGGER.info(f"ONNX output shapes: {[o.shape for o in onnx_out]}")
tolerance_config = ToleranceConfig()
tolerances = tolerance_config.image_tolerances
quat_validator = QuaternionValidator(angular_tolerances=tolerance_config.angular_tolerances_image)
all_passed = True
results = []
for i, name in enumerate(OUTPUT_NAMES):
pt_out = pt_outputs[i]
onnx_output = onnx_out[i]
result = {"output": name, "passed": True, "failure_reason": ""}
if name == "quaternions_rotations":
quat_result = quat_validator.validate(pt_out, onnx_output, image_path.name)
result.update({
"max_diff": f"{quat_result['stats']['max']:.6f}",
"mean_diff": f"{quat_result['stats']['mean']:.6f}",
"p99_diff": f"{quat_result['stats']['p99']:.6f}",
"passed": quat_result["passed"],
"failure_reason": "; ".join(quat_result["failure_reasons"]),
})
if not quat_result["passed"]:
all_passed = False
else:
diff = np.abs(pt_out - onnx_output)
tol = tolerances.get(name, 0.01)
result.update({
"max_diff": f"{np.max(diff):.6f}",
"mean_diff": f"{np.mean(diff):.6f}",
"p99_diff": f"{np.percentile(diff, 99):.6f}",
})
if np.max(diff) > tol:
result["passed"] = False
result["failure_reason"] = f"max diff {np.max(diff):.6f} > tol {tol:.6f}"
all_passed = False
results.append(result)
LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
LOGGER.info(format_validation_table(results, image_path.name, include_image=True))
LOGGER.info("")
return all_passed
def validate_onnx_model(onnx_path, pytorch_model, input_shape=(1536, 1536), angular_tolerances=None):
LOGGER.info("Validating ONNX model against PyTorch...")
np.random.seed(42)
torch.manual_seed(42)
test_image = np.random.rand(1, 3, input_shape[0], input_shape[1]).astype(np.float32)
test_disp = np.array([1.0], dtype=np.float32)
wrapper = SharpModelTraceable(pytorch_model)
wrapper.eval()
with torch.no_grad():
pt_out = wrapper(torch.from_numpy(test_image), torch.from_numpy(test_disp))
session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
onnx_raw = session.run(None, {"image": test_image, "disparity_factor": test_disp})
# Use same splitting logic as run_inference_pair
if len(onnx_raw) == 5:
onnx_splits = list(onnx_raw)
elif len(onnx_raw) == 1:
sizes = [3, 3, 4, 3, 1]
start = 0
onnx_splits = []
for size in sizes:
onnx_splits.append(onnx_raw[0][:, :, start:start+size])
start += size
else:
onnx_splits = list(onnx_raw)
tolerance_config = ToleranceConfig()
tolerances = tolerance_config.random_tolerances
quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances or tolerance_config.angular_tolerances_random)
all_passed = True
results = []
for i, name in enumerate(OUTPUT_NAMES):
pt_o = pt_out[i].numpy()
onnx_o = onnx_splits[i]
result = {"output": name, "passed": True, "failure_reason": ""}
if name == "quaternions_rotations":
qr = quat_validator.validate(pt_o, onnx_o, "Random")
result.update({
"max_diff": f"{qr['stats']['max']:.6f}",
"mean_diff": f"{qr['stats']['mean']:.6f}",
"p99_diff": f"{qr['stats']['p99']:.6f}",
"passed": qr["passed"],
"failure_reason": "; ".join(qr["failure_reasons"]),
})
if not qr["passed"]:
all_passed = False
else:
diff = np.abs(pt_o - onnx_o)
tol = tolerances.get(name, 0.01)
result.update({
"max_diff": f"{np.max(diff):.6f}",
"mean_diff": f"{np.mean(diff):.6f}",
"p99_diff": f"{np.percentile(diff, 99):.6f}",
})
if np.max(diff) > tol:
result["passed"] = False
result["failure_reason"] = f"max diff {np.max(diff):.6f} > tol {tol:.6f}"
all_passed = False
results.append(result)
LOGGER.info("\n### Random Validation Results\n")
LOGGER.info(format_validation_table(results))
LOGGER.info("")
return all_passed
def main():
parser = argparse.ArgumentParser(description="Convert SHARP PyTorch model to ONNX format")
parser.add_argument("-c", "--checkpoint", type=Path, default=None, help="Path to PyTorch checkpoint")
parser.add_argument("-o", "--output", type=Path, default=Path("sharp.onnx"), help="Output path for ONNX model")
parser.add_argument("--height", type=int, default=1536, help="Input image height")
parser.add_argument("--width", type=int, default=1536, help="Input image width")
parser.add_argument("--validate", action="store_true", help="Validate ONNX model against PyTorch")
parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
parser.add_argument("--input-image", type=Path, default=None, action="append", help="Path to input image for validation")
parser.add_argument("--no-external-data", action="store_true", help="Save model with inline data (no .onnx.data file needed)")
parser.add_argument("--fp16", action="store_true", help="Quantize model to FP16 precision (half-precision)")
parser.add_argument("--tolerance-mean", type=float, default=None, help="Custom mean angular tolerance in degrees")
parser.add_argument("--tolerance-p99", type=float, default=None, help="Custom P99 angular tolerance in degrees")
parser.add_argument("--tolerance-max", type=float, default=None, help="Custom max angular tolerance in degrees")
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
LOGGER.info("Loading SHARP model...")
predictor = load_sharp_model(args.checkpoint)
input_shape = (args.height, args.width)
LOGGER.info(f"Converting to ONNX: {args.output}")
    # External-data handling: --no-external-data forces inline weights (the
    # browser-friendly form); otherwise defer to convert_to_onnx's size-based default.
    use_external_data = False if args.no_external_data else None
convert_to_onnx(predictor, args.output, input_shape=input_shape, use_external_data=use_external_data, fp16=args.fp16)
LOGGER.info(f"ONNX model saved to {args.output}")
# Skip validation for FP16 models since they have inherent precision differences from FP32
if args.validate and args.fp16:
LOGGER.info("Validation skipped for FP16 model (precision differences expected)")
LOGGER.info("Conversion complete!")
return 0
if args.validate:
if args.input_image:
for img_path in args.input_image:
if not img_path.exists():
LOGGER.error(f"Image not found: {img_path}")
return 1
passed = validate_with_image(args.output, predictor, img_path, input_shape)
if not passed:
LOGGER.error(f"Validation failed for {img_path}")
return 1
else:
            angular_tolerances = None
            if any(t is not None for t in (args.tolerance_mean, args.tolerance_p99, args.tolerance_max)):
                angular_tolerances = {
                    "mean": args.tolerance_mean if args.tolerance_mean is not None else 0.01,
                    "p99": args.tolerance_p99 if args.tolerance_p99 is not None else 0.5,
                    "p99_9": 2.0,
                    "max": args.tolerance_max if args.tolerance_max is not None else 15.0,
                }
passed = validate_onnx_model(args.output, predictor, input_shape, angular_tolerances=angular_tolerances)
if passed:
LOGGER.info("Validation passed!")
else:
LOGGER.error("Validation failed!")
return 1
LOGGER.info("Conversion complete!")
return 0
if __name__ == "__main__":
    raise SystemExit(main())