| """Convert SHARP PyTorch model to ONNX format.""" |
|
|
from __future__ import annotations

import argparse
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
import torch.nn.functional as F

from sharp.models import PredictorParams, create_predictor
from sharp.models.predictor import RGBGaussianPredictor
from sharp.utils import io
|
|
LOGGER = logging.getLogger(__name__)
# Default pretrained checkpoint fetched when no local checkpoint is given.
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"


# Canonical ONNX output names. The ORDER matters: it must match the tuple
# returned by SharpModelTraceable.forward, and the validation code below
# pairs PyTorch/ONNX outputs positionally against this list.
OUTPUT_NAMES = [
    "mean_vectors_3d_positions",
    "singular_values_scales",
    "quaternions_rotations",
    "colors_rgb_linear",
    "opacities_alpha_channel",
]
|
|
|
|
@dataclass
class ToleranceConfig:
    """Tolerances used when comparing ONNX outputs against PyTorch.

    Per-output values are max-absolute-difference thresholds; the angular
    tolerances (degrees) apply to quaternion comparisons. The `random_*`
    values cover the synthetic random-input check, while the `image_*`
    values are looser to absorb numeric drift on real images.

    Uses ``field(default_factory=...)`` instead of ``None`` defaults plus
    ``__post_init__`` — the annotations are now truthful and the defaults
    are created per instance.
    """

    # Max-abs-diff tolerance per output name, random-input validation.
    random_tolerances: dict[str, float] = field(default_factory=lambda: {
        "mean_vectors_3d_positions": 0.001,
        "singular_values_scales": 0.0001,
        "quaternions_rotations": 10.0,
        "colors_rgb_linear": 0.002,
        "opacities_alpha_channel": 0.005,
    })
    # Looser per-output tolerances for validation against a real image.
    image_tolerances: dict[str, float] = field(default_factory=lambda: {
        "mean_vectors_3d_positions": 3.5,
        "singular_values_scales": 0.035,
        "quaternions_rotations": 10.0,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
    })
    # Angular thresholds (degrees) on quaternion difference stats, random input.
    angular_tolerances_random: dict[str, float] = field(
        default_factory=lambda: {"mean": 0.01, "p99": 0.1, "p99_9": 1.0, "max": 10.0}
    )
    # Angular thresholds (degrees) for real-image validation.
    angular_tolerances_image: dict[str, float] = field(
        default_factory=lambda: {"mean": 0.2, "p99": 2.0, "p99_9": 5.0, "max": 25.0}
    )
|
|
|
|
class QuaternionValidator:
    """Validates quaternion outputs by angular difference rather than raw
    element-wise difference, since q and -q encode the same rotation."""

    def __init__(self, angular_tolerances=None, enable_outlier_analysis=True, outlier_thresholds=None):
        # Thresholds (degrees) on the angular-difference summary statistics.
        self.angular_tolerances = angular_tolerances or {"mean": 0.01, "p99": 0.5, "p99_9": 2.0, "max": 15.0}
        self.enable_outlier_analysis = enable_outlier_analysis
        self.outlier_thresholds = outlier_thresholds or [5.0, 10.0, 15.0]

    @staticmethod
    def canonicalize_quaternion(q):
        """Flip each quaternion's sign so its largest-|component| is positive."""
        abs_q = np.abs(q)
        max_idx = np.argmax(abs_q, axis=-1, keepdims=True)
        selector = np.zeros_like(q)
        np.put_along_axis(selector, max_idx, 1.0, axis=-1)
        max_sign = np.sum(q * selector, axis=-1, keepdims=True)
        return np.where(max_sign < 0, -q, q)

    @staticmethod
    def compute_angular_differences(quats1, quats2):
        """Return per-quaternion angular differences (degrees) and summary stats.

        Inputs are normalized (norm clamped away from zero) and canonicalized
        first. |dot| makes the comparison invariant to any residual q/-q sign
        ambiguity. (The previous extra dot against -quats2 was redundant:
        |q1 . (-q2)| == |q1 . q2|, so it has been removed — results are
        identical.)
        """
        n1 = np.linalg.norm(quats1, axis=-1, keepdims=True)
        n2 = np.linalg.norm(quats2, axis=-1, keepdims=True)
        q1 = quats1 / np.clip(n1, 1e-12, None)
        q2 = quats2 / np.clip(n2, 1e-12, None)
        q1 = QuaternionValidator.canonicalize_quaternion(q1)
        q2 = QuaternionValidator.canonicalize_quaternion(q2)
        # Clip protects arccos from values just above 1.0 due to rounding.
        dots = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
        ang_deg = np.degrees(2.0 * np.arccos(dots))
        return ang_deg, {
            "mean": float(np.mean(ang_deg)),
            "std": float(np.std(ang_deg)),
            "max": float(np.max(ang_deg)),
            "p99": float(np.percentile(ang_deg, 99)),
            "p99_9": float(np.percentile(ang_deg, 99.9)),
        }

    def validate(self, pt_quats, onnx_quats, image_name="Unknown"):
        """Compare two quaternion arrays against the angular tolerances.

        Returns a dict with ``image``, ``passed``, ``failure_reasons`` and the
        raw ``stats`` from :meth:`compute_angular_differences`.
        """
        diff, stats = self.compute_angular_differences(pt_quats, onnx_quats)
        passed = True
        reasons = []
        for key, tol in self.angular_tolerances.items():
            if key in stats and stats[key] > tol:
                passed = False
                reasons.append(f"{key} angular {stats[key]:.4f} > {tol:.4f}")
        return {"image": image_name, "passed": passed, "failure_reasons": reasons, "stats": stats}
|
|
|
|
class SharpModelTraceable(nn.Module):
    """ONNX-export-friendly wrapper around the SHARP predictor.

    Flattens the predictor's pipeline into a single forward pass that returns
    plain tensors (positions, scales, canonicalized quaternions, colors,
    opacities) in the order expected by OUTPUT_NAMES.
    """

    def __init__(self, predictor):
        super().__init__()
        # Re-expose the predictor's submodules so torch.onnx.export traces
        # them directly instead of going through the predictor's own API.
        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment

    def forward(self, image, disparity_factor):
        mono = self.monodepth_model(image)

        # depth = factor / disparity, with disparity clamped away from
        # zero/infinity for numeric stability.
        factor = disparity_factor[:, None, None, None]
        depth = factor / mono.disparity.clamp(min=1e-4, max=1e4)
        depth, _ = self.depth_alignment(depth, None, mono.decoder_features)

        init_out = self.init_model(image, depth)
        features = self.feature_model(init_out.feature_input, encodings=mono.output_features)
        deltas = self.prediction_head(features)
        gaussians = self.gaussian_composer(deltas, init_out.gaussian_base_values, init_out.global_scale)

        # Normalize quaternions, then canonicalize the sign so the component
        # with the largest magnitude is positive (q and -q are equivalent).
        q = gaussians.quaternions
        norm = torch.sqrt(torch.clamp(torch.sum(q * q, dim=-1, keepdim=True), min=1e-12))
        q = q / norm
        dominant = torch.zeros_like(q)
        dominant.scatter_(-1, torch.argmax(torch.abs(q), dim=-1, keepdim=True), 1.0)
        sign = torch.sum(q * dominant, dim=-1, keepdim=True)
        q = torch.where(sign < 0, -q, q).float()

        return (gaussians.mean_vectors, gaussians.singular_values, q, gaussians.colors, gaussians.opacities)
|
|
|
|
def _remove_logged(path):
    """Delete *path* if it exists, logging success; warn (never raise) on failure."""
    try:
        if path.exists():
            path.unlink()
            LOGGER.info(f"Removed {path}")
    except Exception as e:
        LOGGER.warning(f"Could not remove {path}: {e}")


def cleanup_onnx_files(onnx_path):
    """Clean up ONNX model files including external data files.

    Removes the .onnx file, its companion .onnx.data external-weights file,
    and any stray per-tensor dump files left in the current directory by a
    previous export. Best-effort: failures are logged, never raised.
    """
    # The two try/except blocks of the original are now one shared helper.
    _remove_logged(onnx_path)
    _remove_logged(onnx_path.with_suffix('.onnx.data'))

    # Stray files an interrupted/legacy export may leave in the CWD.
    temp_patterns = ["onnx__*", "monodepth_*", "feature_model*", "_Constant_*", "_init_model_*"]
    import glob
    for pattern in temp_patterns:
        for f in glob.glob(pattern):
            try:
                Path(f).unlink()
                LOGGER.info(f"Removed temporary file {f}")
            except Exception:
                pass
|
|
|
|
def cleanup_extraneous_files():
    """Best-effort removal of stray tensor-dump files left in the CWD by export."""
    import glob, os
    for pattern in ("onnx__*", "monodepth_*", "feature_model*", "_Constant_*", "_init_model_*"):
        for leftover in glob.glob(pattern):
            try:
                os.remove(leftover)
            except Exception:
                # Ignore files we cannot delete; this is purely housekeeping.
                pass
|
|
|
|
def load_sharp_model(checkpoint_path=None):
    """Return a SHARP predictor in eval mode.

    Loads weights from *checkpoint_path* when given, otherwise downloads the
    default checkpoint from DEFAULT_MODEL_URL.
    """
    if checkpoint_path is not None:
        LOGGER.info(f"Loading checkpoint from {checkpoint_path}")
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    else:
        LOGGER.info(f"Downloading model from {DEFAULT_MODEL_URL}")
        state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
    predictor = create_predictor(PredictorParams())
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor
|
|
|
|
def convert_to_fp16(onnx_path):
    """Convert an ONNX model to FP16 precision, in place.

    Loads the model, casts every float32 initializer to float16, and switches
    the element types of true graph inputs and outputs to float16. The result
    is a smaller model with faster inference on FP16-capable hardware.

    NOTE(review): this casts initializers directly without inserting Cast
    nodes, so graphs whose nodes require float32 intermediates may become
    type-inconsistent — verify the saved model loads in an ORT session.

    Returns:
        onnx_path, for chaining.
    """
    LOGGER.info(f"Converting {onnx_path} to FP16...")

    model = onnx.load(str(onnx_path))

    # Cast all float32 weight tensors to float16.
    for tensor in model.graph.initializer:
        if tensor.data_type == onnx.TensorProto.FLOAT:
            fp16_weights = onnx.numpy_helper.to_array(tensor).astype(np.float16)
            tensor.CopyFrom(onnx.numpy_helper.from_array(fp16_weights, tensor.name))

    # Switch true graph inputs to float16 (initializer-backed "inputs" are
    # weights, already handled above — skip them).
    for inp in model.graph.input:
        if any(init.name == inp.name for init in model.graph.initializer):
            continue
        if inp.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
            inp.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16

    for out in model.graph.output:
        if out.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
            out.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16

    # Bump the default-domain opset to at least 13 for the converted model.
    for opset in model.opset_import:
        if opset.domain == "" and opset.version < 13:
            opset.version = 13

    # Register the com.microsoft domain (ORT contrib ops) exactly once.
    # (Replaces the original manual flag loop, which was also misnamed
    # `has_ai_onnx_edge`.)
    if not any(op.domain == "com.microsoft" for op in model.opset_import):
        ms_opset = model.opset_import.add()
        ms_opset.domain = "com.microsoft"
        ms_opset.version = 1

    onnx.save(model, str(onnx_path))

    size_mb = Path(onnx_path).stat().st_size / (1024 * 1024)
    LOGGER.info(f"FP16 model saved: {onnx_path} ({size_mb:.2f} MB)")
    return onnx_path
|
|
|
|
def convert_to_onnx(predictor, output_path, input_shape=(1536, 1536), use_external_data=None, fp16=False):
    """Export the SHARP predictor to ONNX at *output_path*.

    Args:
        predictor: Loaded SHARP predictor.
        output_path: Destination ``.onnx`` path; stale files are removed first.
        input_shape: (height, width) of the traced example input.
        use_external_data: True/False to force the storage format; None picks
            external storage automatically when the model exceeds 100 MB.
        fp16: When True, also convert the saved model to FP16 in place.

    Returns:
        output_path, for chaining.
    """
    LOGGER.info("Exporting to ONNX format...")
    # The scale-map estimator is dropped before wrapping for export.
    predictor.depth_alignment.scale_map_estimator = None
    model = SharpModelTraceable(predictor)
    model.eval()

    LOGGER.info("Pre-warming model...")
    with torch.no_grad():
        for _ in range(3):
            _ = model(torch.randn(1, 3, input_shape[0], input_shape[1]), torch.tensor([1.0]))

    cleanup_onnx_files(output_path)

    h, w = input_shape
    torch.manual_seed(42)
    example_image = torch.randn(1, 3, h, w)
    example_disparity = torch.tensor([1.0])

    LOGGER.info(f"Exporting to ONNX: {output_path}")

    # Every output shares the same dynamic layout: (batch, num_gaussians, C).
    # (The original if/else branched per output name but produced identical
    # axes in both branches — the dead branch is removed.)
    dynamic_axes = {name: {0: 'batch', 1: 'num_gaussians'} for name in OUTPUT_NAMES}

    torch.onnx.export(
        model, (example_image, example_disparity), str(output_path),
        export_params=True, verbose=False,
        input_names=['image', 'disparity_factor'],
        output_names=OUTPUT_NAMES,
        dynamic_axes=dynamic_axes,
        opset_version=15,
    )

    try:
        model_proto = onnx.load(str(output_path))
        model_size_mb = model_proto.ByteSize() / (1024 * 1024)
        LOGGER.info(f"Model size: {model_size_mb:.2f} MB")

        # Explicit flag wins; otherwise use external storage for models >100 MB.
        use_ext = use_external_data if use_external_data is not None else (model_size_mb > 100)

        if use_ext:
            LOGGER.info("Saving with external data format...")
            data_path = output_path.with_suffix('.onnx.data')
            onnx.save_model(model_proto, str(output_path), save_as_external_data=True,
                            all_tensors_to_one_file=True, location=data_path.name)
            LOGGER.info(f"External data saved to: {data_path}")
        else:
            LOGGER.info("Using inline data format (no external .onnx.data file needed)")
    except Exception as e:
        LOGGER.warning(f"External data format check failed: {e}")

    try:
        onnx.checker.check_model(str(output_path))
        LOGGER.info("ONNX model validation passed")
    except Exception as e:
        LOGGER.warning(f"ONNX model validation skipped: {e}")

    if fp16:
        convert_to_fp16(output_path)

    cleanup_extraneous_files()
    return output_path
|
|
|
|
def find_onnx_output_key(name, onnx_outputs):
    """Map a canonical output name to whichever key the ONNX session produced.

    Tries, in order: an exact key match, a case-insensitive substring match on
    the name's first token, then a positional fallback via OUTPUT_NAMES.
    """
    if name in onnx_outputs:
        return name
    prefix = name.split('_')[0]
    for candidate in onnx_outputs:
        if prefix in candidate.lower():
            return candidate
    keys = list(onnx_outputs.keys())
    fallback_index = OUTPUT_NAMES.index(name) if name in OUTPUT_NAMES else 0
    return keys[fallback_index]
|
|
|
|
def load_and_preprocess_image(image_path, target_size=(1536, 1536)):
    """Load an RGB image as a normalized (1, 3, H, W) float tensor.

    Returns (tensor, f_px, orig_size) where orig_size is (width, height) and
    f_px is the focal length in pixels reported by the loader.
    """
    LOGGER.info(f"Loading image from {image_path}")
    image_np, orig_size, f_px = io.load_rgb(image_path)

    if orig_size is None:
        # Fall back to the array's own (width, height) when the loader gives none.
        orig_size = (image_np.shape[1], image_np.shape[0])
    LOGGER.info(f"Original size: {orig_size}, focal: {f_px:.2f}px")

    chw = torch.from_numpy(image_np).float() / 255.0
    chw = chw.permute(2, 0, 1)

    # orig_size is (w, h); target_size is (h, w) — compare matching axes.
    if (orig_size[0], orig_size[1]) != (target_size[1], target_size[0]):
        LOGGER.info(f"Resizing to {target_size[1]}x{target_size[0]}")
        chw = F.interpolate(chw.unsqueeze(0), size=target_size, mode="bilinear", align_corners=True).squeeze(0)

    batched = chw.unsqueeze(0)
    LOGGER.info(f"Preprocessed shape: {batched.shape}, range: [{batched.min():.4f}, {batched.max():.4f}]")
    return batched, f_px, orig_size
|
|
|
|
def run_inference_pair(pytorch_model, onnx_path, image_tensor, disparity_factor=1.0, log_internals=False):
    """Run the same input through PyTorch and ONNX Runtime (CPU).

    Args:
        pytorch_model: SHARP predictor to wrap with SharpModelTraceable.
        onnx_path: Path to the exported ``.onnx`` model.
        image_tensor: Image tensor fed to both backends (cast to float32).
        disparity_factor: Scalar disparity factor fed to both backends.
        log_internals: Accepted for interface compatibility; currently unused.

    Returns:
        (pt_np, onnx_splits): two lists of numpy arrays in OUTPUT_NAMES order.
    """
    wrapper = SharpModelTraceable(pytorch_model)
    wrapper.eval()
    image_tensor = image_tensor.float()
    disp_pt = torch.tensor([disparity_factor], dtype=torch.float32)
    with torch.no_grad():
        pt_outputs = wrapper(image_tensor, disp_pt)

    pt_np = [o.numpy() for o in pt_outputs]

    session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
    onnx_inputs = {"image": image_tensor.numpy(), "disparity_factor": np.array([disparity_factor], dtype=np.float32)}
    onnx_raw = session.run(None, onnx_inputs)

    LOGGER.info(f"ONNX raw outputs count: {len(onnx_raw)}, first shape: {onnx_raw[0].shape if len(onnx_raw) > 0 else 'N/A'}")

    if len(onnx_raw) == 5:
        # One tensor per output, already in OUTPUT_NAMES order.
        onnx_splits = list(onnx_raw)
    elif len(onnx_raw) == 1:
        # Single fused tensor: slice the last axis into the five attributes.
        total_size = onnx_raw[0].shape[-1]
        LOGGER.info(f"ONNX single output total size: {total_size}")

        # Channel widths: positions, scales, quaternions, colors, opacity.
        sizes = [3, 3, 4, 3, 1]
        start = 0
        onnx_splits = []
        for size in sizes:  # (original used enumerate with an unused index)
            onnx_splits.append(onnx_raw[0][:, :, start:start + size])
            start += size
    else:
        onnx_splits = list(onnx_raw)

    return pt_np, onnx_splits
|
|
|
|
def format_validation_table(results, image_name="", include_image=False):
    """Render per-output validation results as a markdown table string.

    Args:
        results: Dicts with keys ``output``, ``passed``, ``max_diff``,
            ``mean_diff`` and ``p99_diff`` (diff values already formatted).
        image_name: Shown in the leading column when *include_image* is True.
        include_image: Prepend an "Image" column to every row.

    Returns:
        The table as one newline-joined string.
    """
    # The original duplicated the row-building loop across both branches;
    # only the header and an optional row prefix actually differ.
    if include_image:
        header = "| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|-------|--------|----------|-----------|----------|--------|"
    else:
        header = "| Output | Max Diff | Mean Diff | P99 Diff | Status |"
        divider = "|--------|----------|-----------|----------|--------|"

    lines = [header, divider]
    for r in results:
        name = r["output"].replace("_", " ").title()
        status = "PASS" if r["passed"] else "FAIL"
        prefix = f"| {image_name} " if include_image else ""
        lines.append(f"{prefix}| {name} | {r['max_diff']} | {r['mean_diff']} | {r['p99_diff']} | {status} |")
    return "\n".join(lines)
|
|
|
|
def validate_with_image(onnx_path, pytorch_model, image_path, input_shape=(1536, 1536)):
    """Compare PyTorch vs ONNX outputs on a real image.

    Returns True when every output stays within the image tolerances.
    """
    LOGGER.info(f"Validating with image: {image_path}")
    test_image, f_px, (w, h) = load_and_preprocess_image(image_path, input_shape)
    # Disparity factor is the focal length normalized by image width.
    disparity_factor = f_px / w
    LOGGER.info(f"Using disparity_factor = {disparity_factor:.6f}")

    pt_outputs, onnx_out = run_inference_pair(pytorch_model, onnx_path, test_image, disparity_factor)

    LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
    LOGGER.info(f"ONNX output shapes: {[o.shape for o in onnx_out]}")

    config = ToleranceConfig()
    tolerances = config.image_tolerances
    quat_validator = QuaternionValidator(angular_tolerances=config.angular_tolerances_image)

    results = []
    all_passed = True

    for idx, name in enumerate(OUTPUT_NAMES):
        reference = pt_outputs[idx]
        candidate = onnx_out[idx]
        row = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Quaternions are compared angularly (q and -q are the same rotation).
            qr = quat_validator.validate(reference, candidate, image_path.name)
            row["max_diff"] = f"{qr['stats']['max']:.6f}"
            row["mean_diff"] = f"{qr['stats']['mean']:.6f}"
            row["p99_diff"] = f"{qr['stats']['p99']:.6f}"
            row["passed"] = qr["passed"]
            row["failure_reason"] = "; ".join(qr["failure_reasons"])
            all_passed = all_passed and qr["passed"]
        else:
            diff = np.abs(reference - candidate)
            max_diff = np.max(diff)
            row["max_diff"] = f"{max_diff:.6f}"
            row["mean_diff"] = f"{np.mean(diff):.6f}"
            row["p99_diff"] = f"{np.percentile(diff, 99):.6f}"
            tol = tolerances.get(name, 0.01)
            if max_diff > tol:
                row["passed"] = False
                row["failure_reason"] = f"max diff {max_diff:.6f} > tol {tol:.6f}"
                all_passed = False

        results.append(row)

    LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
    LOGGER.info(format_validation_table(results, image_path.name, include_image=True))
    LOGGER.info("")

    return all_passed
|
|
|
|
def validate_onnx_model(onnx_path, pytorch_model, input_shape=(1536, 1536), angular_tolerances=None):
    """Compare PyTorch vs ONNX outputs on a fixed seeded random input.

    Returns True when every output stays within the random-input tolerances.
    """
    LOGGER.info("Validating ONNX model against PyTorch...")
    # Fixed seeds keep the comparison reproducible across runs.
    np.random.seed(42)
    torch.manual_seed(42)

    test_image = np.random.rand(1, 3, input_shape[0], input_shape[1]).astype(np.float32)
    test_disp = np.array([1.0], dtype=np.float32)

    wrapper = SharpModelTraceable(pytorch_model)
    wrapper.eval()
    with torch.no_grad():
        pt_out = wrapper(torch.from_numpy(test_image), torch.from_numpy(test_disp))

    session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
    onnx_raw = session.run(None, {"image": test_image, "disparity_factor": test_disp})

    if len(onnx_raw) == 1:
        # Single fused output: slice the last axis back into the five tensors
        # (positions, scales, quaternions, colors, opacity).
        widths = [3, 3, 4, 3, 1]
        onnx_splits = []
        offset = 0
        for width in widths:
            onnx_splits.append(onnx_raw[0][:, :, offset:offset + width])
            offset += width
    else:
        onnx_splits = list(onnx_raw)

    config = ToleranceConfig()
    tolerances = config.random_tolerances
    quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances or config.angular_tolerances_random)

    results = []
    all_passed = True

    for idx, name in enumerate(OUTPUT_NAMES):
        reference = pt_out[idx].numpy()
        candidate = onnx_splits[idx]
        row = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            qr = quat_validator.validate(reference, candidate, "Random")
            row["max_diff"] = f"{qr['stats']['max']:.6f}"
            row["mean_diff"] = f"{qr['stats']['mean']:.6f}"
            row["p99_diff"] = f"{qr['stats']['p99']:.6f}"
            row["passed"] = qr["passed"]
            row["failure_reason"] = "; ".join(qr["failure_reasons"])
            all_passed = all_passed and qr["passed"]
        else:
            diff = np.abs(reference - candidate)
            max_diff = np.max(diff)
            row["max_diff"] = f"{max_diff:.6f}"
            row["mean_diff"] = f"{np.mean(diff):.6f}"
            row["p99_diff"] = f"{np.percentile(diff, 99):.6f}"
            tol = tolerances.get(name, 0.01)
            if max_diff > tol:
                row["passed"] = False
                row["failure_reason"] = f"max diff {max_diff:.6f} > tol {tol:.6f}"
                all_passed = False

        results.append(row)

    LOGGER.info("\n### Random Validation Results\n")
    LOGGER.info(format_validation_table(results))
    LOGGER.info("")

    return all_passed
|
|
|
|
def main():
    """CLI entry point: load SHARP, export it to ONNX, optionally validate.

    Returns:
        Process exit code: 0 on success, 1 on validation failure or a
        missing validation image.
    """
    parser = argparse.ArgumentParser(description="Convert SHARP PyTorch model to ONNX format")
    parser.add_argument("-c", "--checkpoint", type=Path, default=None, help="Path to PyTorch checkpoint")
    parser.add_argument("-o", "--output", type=Path, default=Path("sharp.onnx"), help="Output path for ONNX model")
    parser.add_argument("--height", type=int, default=1536, help="Input image height")
    parser.add_argument("--width", type=int, default=1536, help="Input image width")
    parser.add_argument("--validate", action="store_true", help="Validate ONNX model against PyTorch")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
    parser.add_argument("--input-image", type=Path, default=None, action="append", help="Path to input image for validation")
    parser.add_argument("--no-external-data", action="store_true", help="Save model with inline data (no .onnx.data file needed)")
    parser.add_argument("--fp16", action="store_true", help="Quantize model to FP16 precision (half-precision)")
    parser.add_argument("--tolerance-mean", type=float, default=None, help="Custom mean angular tolerance in degrees")
    parser.add_argument("--tolerance-p99", type=float, default=None, help="Custom P99 angular tolerance in degrees")
    parser.add_argument("--tolerance-max", type=float, default=None, help="Custom max angular tolerance in degrees")

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
                        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    input_shape = (args.height, args.width)

    LOGGER.info(f"Converting to ONNX: {args.output}")

    # NOTE(review): passing a bool here means convert_to_onnx's size-based
    # auto-selection (triggered by use_external_data=None) is unreachable
    # from the CLI — confirm whether "always external unless flagged" is intended.
    use_external_data = not args.no_external_data
    convert_to_onnx(predictor, args.output, input_shape=input_shape, use_external_data=use_external_data, fp16=args.fp16)
    LOGGER.info(f"ONNX model saved to {args.output}")

    if args.validate and args.fp16:
        # FP16 outputs legitimately diverge from FP32 PyTorch; skip comparison.
        LOGGER.info("Validation skipped for FP16 model (precision differences expected)")
        LOGGER.info("Conversion complete!")
        return 0

    if args.validate:
        if args.input_image:
            for img_path in args.input_image:
                if not img_path.exists():
                    LOGGER.error(f"Image not found: {img_path}")
                    return 1
                passed = validate_with_image(args.output, predictor, img_path, input_shape)
                if not passed:
                    LOGGER.error(f"Validation failed for {img_path}")
                    return 1
        else:
            angular_tolerances = None
            # Use `is not None` so an explicit 0.0 tolerance is honored;
            # the original truthiness checks silently ignored zero values.
            overrides = (args.tolerance_mean, args.tolerance_p99, args.tolerance_max)
            if any(t is not None for t in overrides):
                angular_tolerances = {
                    "mean": args.tolerance_mean if args.tolerance_mean is not None else 0.01,
                    "p99": args.tolerance_p99 if args.tolerance_p99 is not None else 0.5,
                    "p99_9": 2.0,
                    "max": args.tolerance_max if args.tolerance_max is not None else 15.0,
                }
            passed = validate_onnx_model(args.output, predictor, input_shape, angular_tolerances=angular_tolerances)
            if passed:
                LOGGER.info("Validation passed!")
            else:
                LOGGER.error("Validation failed!")
                return 1

    LOGGER.info("Conversion complete!")
    return 0
|
|
|
|
if __name__ == "__main__":
    # sys.exit is the canonical form; the builtin exit() is a site.py
    # convenience that may be absent (e.g. under `python -S` or frozen builds).
    sys.exit(main())
|
|