Sharp-onnx / inference_onnx.py
Kyle Pearson
Update framework to ONNX Runtime (FP32/FP16), remove Apple dependencies, add validation script for ONNX conversion with FP32-preserving ops, fix FP16 precision issues, update inference CLI with depth exaggeration, rename docs, and enable LFS support.
5cd2df6
#!/usr/bin/env python3
"""ONNX Inference Script for SHARP Model.
Loads an ONNX model (fp32 or fp16), runs inference on an input image,
and exports the result as a PLY file.
Usage:
# Convert and validate FP16 model
python convert_onnx.py -o sharp_fp16.onnx -q fp16 --validate
# Run inference with FP16 model, keeping the 50% most important Gaussians (-d/--decimate 0.5)
python inference_onnx.py -m sharp_fp16.onnx -i test.png -o test.ply -d 0.5
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
import numpy as np
import onnxruntime as ort
from PIL import Image
from plyfile import PlyData, PlyElement
# Module-wide logging: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
LOGGER = logging.getLogger(__name__)

# Resolution (height, width) that preprocess_image resizes inputs to by default.
DEFAULT_HEIGHT = DEFAULT_WIDTH = 1536
def linear_to_srgb(linear: float) -> float:
    """Encode one linear-light channel value with the sRGB transfer curve.

    Values at or below the 0.0031308 breakpoint use the linear toe
    (``12.92 * v``); larger values use ``1.055 * v ** (1 / 2.4) - 0.055``.
    Expects a scalar, because the branch is a plain Python conditional.
    """
    toe_limit = 0.0031308
    gamma = 1.0 / 2.4
    return linear * 12.92 if linear <= toe_limit else 1.055 * pow(linear, gamma) - 0.055
def rgb_to_sh(rgb: float) -> float:
    """Convert a colour value (nominally in [0, 1]) to a degree-0 SH coefficient.

    Inverts ``rgb = 0.5 + C0 * sh`` where ``C0 = 1 / sqrt(4 * pi)``. Only
    arithmetic is used, so NumPy arrays are converted element-wise as well.
    """
    c0 = 1.0 / np.sqrt(4.0 * np.pi)
    centered = rgb - 0.5
    return centered / c0
def inverse_sigmoid(x: float) -> float:
    """Return the logit ``log(x / (1 - x))``, i.e. the inverse of the sigmoid.

    ``x`` is first clipped to ``[1e-6, 1 - 1e-6]``, so inputs of exactly 0 or 1
    give large finite values instead of +/-inf. Works element-wise on arrays.
    """
    eps = 1e-6
    p = np.clip(x, eps, 1.0 - eps)
    odds = p / (1.0 - p)
    return np.log(odds)
def preprocess_image(image_path: str | Path, target_size: tuple[int, int] = (DEFAULT_HEIGHT, DEFAULT_WIDTH)):
    """Load an image and turn it into the model's input tensor.

    The image is resized with bilinear filtering to ``target_size``; aspect
    ratio is not preserved. It is converted to 3-channel RGB, and any alpha
    channel is dropped.

    Args:
        image_path: Path to the input image.
        target_size: ``(height, width)`` to resize to before inference.

    Returns:
        A tuple ``(image, focal_length_px, original_size)``:

        - ``image`` is a float32 array of shape (1, 3, height, width) with
          8-bit channel values scaled into [0, 1].
        - ``focal_length_px`` is the original image width in pixels, used as a
          stand-in focal length. No EXIF lookup is done.
        - ``original_size`` is PIL's ``(width, height)`` of the unresized image.
    """
    image_path = Path(image_path)
    target_h, target_w = target_size

    # Context manager so the underlying file handle is released.
    with Image.open(image_path) as src:
        original_size = src.size
        # Palette, grayscale, LA, CMYK and other modes become RGB before
        # resampling, so np.array() below always sees exactly 3 channels.
        # RGBA is resized as-is (as before) and its alpha is dropped afterwards.
        img = src if src.mode in ("RGB", "RGBA") else src.convert("RGB")
        if img.size != (target_w, target_h):
            img = img.resize((target_w, target_h), Image.BILINEAR)
        if img.mode != "RGB":
            # RGBA -> RGB discards alpha without compositing, like slicing [..., :3].
            img = img.convert("RGB")
        img_np = np.array(img, dtype=np.float32) / 255.0

    focal_length_px = original_size[0]

    # HWC -> CHW, then add a leading batch axis of size 1.
    img_np = np.transpose(img_np, (2, 0, 1))
    img_np = np.expand_dims(img_np, axis=0)

    LOGGER.info(f"Loaded image: {image_path}, original size: {original_size}")
    LOGGER.info(f"Preprocessed shape: {img_np.shape}, range: [{img_np.min():.4f}, {img_np.max():.4f}]")
    return img_np, float(focal_length_px), original_size
def _name_model_outputs(raw_outputs: list, output_names: list) -> dict:
    """Map raw ONNX Runtime outputs onto the Gaussian attribute names export_ply reads."""
    names = (
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel",
    )
    # Channel width of each attribute inside a single concatenated output tensor.
    sizes = (3, 3, 4, 3, 1)

    if len(raw_outputs) == 1:
        concat = raw_outputs[0]
        expected = sum(sizes)
        if concat.shape[-1] != expected:
            raise ValueError(
                f"Expected a concatenated output with {expected} channels in its last axis, "
                f"got shape {concat.shape}"
            )
        outputs = {}
        start = 0
        for name, size in zip(names, sizes):
            outputs[name] = concat[..., start:start + size]
            start += size
        return outputs

    if len(raw_outputs) == len(names):
        if set(output_names) == set(names):
            # The model already uses the canonical names: match by name, not position.
            by_name = dict(zip(output_names, raw_outputs))
            return {name: by_name[name] for name in names}
        # Otherwise assume the outputs come in the canonical order.
        return dict(zip(names, raw_outputs))

    # Unknown layout: keep the model's own output names.
    return dict(zip(output_names, raw_outputs))


def run_inference(onnx_path: str | Path, image: np.ndarray, disparity_factor: float = 1.0) -> dict[str, np.ndarray]:
    """Run the ONNX model on a preprocessed image and return its outputs by name.

    Args:
        onnx_path: Path to the ``.onnx`` model file.
        image: Preprocessed image batch, e.g. from ``preprocess_image``.
        disparity_factor: Scalar fed to the model's ``disparity_factor`` input,
            if the model declares one.

    Returns:
        Dict of output arrays. A single concatenated output (last axis of 14)
        or five separate outputs are returned under the canonical Gaussian
        attribute names. Any other layout uses the model's own output names.

    Raises:
        ValueError: If a single concatenated output does not have 14 channels
            in its last axis.
    """
    onnx_path = Path(onnx_path)
    LOGGER.info(f"Loading ONNX model: {onnx_path}")

    # Only log errors: constant-folding warnings for FP16 Sqrt/Tile ops are benign.
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 3  # 0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal

    # CPUExecutionProvider for portability across platforms.
    session = ort.InferenceSession(str(onnx_path), sess_options, providers=['CPUExecutionProvider'])
    LOGGER.info("Using CPUExecutionProvider for inference")

    input_meta = session.get_inputs()
    input_names = [inp.name for inp in input_meta]
    output_names = [out.name for out in session.get_outputs()]
    LOGGER.info(f"Input names: {input_names}")
    LOGGER.info(f"Output names: {output_names}")

    # Feed only the inputs the model declares, cast to the declared element
    # type. This lets FP16-input exports and models without a
    # disparity_factor input both work. A missing required input is still
    # reported by ONNX Runtime.
    candidates = {
        "image": np.asarray(image),
        "disparity_factor": np.array([disparity_factor]),
    }
    onnx_to_numpy = {"tensor(float16)": np.float16, "tensor(double)": np.float64}
    inputs = {
        meta.name: candidates[meta.name].astype(onnx_to_numpy.get(meta.type, np.float32))
        for meta in input_meta
        if meta.name in candidates
    }

    LOGGER.info("Running inference...")
    raw_outputs = session.run(None, inputs)

    outputs = _name_model_outputs(raw_outputs, output_names)
    for name, arr in outputs.items():
        LOGGER.info(f" {name}: shape {arr.shape}")
    return outputs
def _linear_to_srgb_array(linear: np.ndarray) -> np.ndarray:
    """Vectorized counterpart of ``linear_to_srgb`` (same piecewise sRGB curve).

    Evaluated in float64, like the scalar function applied to ``float(value)``.
    """
    lin = np.asarray(linear, dtype=np.float64)
    # Clamp the power branch's input so values np.where discards anyway
    # (negatives) never reach np.power and emit invalid-value warnings.
    curve = 1.055 * np.power(np.maximum(lin, 0.0031308), 1.0 / 2.4) - 0.055
    return np.where(lin <= 0.0031308, lin * 12.92, curve)


def _ply_metadata_elements(focal_length_px: float, img_h: int, img_w: int,
                           num_gaussians: int, z_values: np.ndarray) -> list:
    """Build the one-row metadata PLY elements written after the vertex element."""
    # Camera pose: 4x4 identity matrix, flattened row-major.
    extrinsic_data = np.zeros(1, dtype=[('extrinsic', 'f4', (16,))])
    extrinsic_data['extrinsic'][0] = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]

    # Pinhole intrinsics (3x3, row-major) with the principal point at the image centre.
    intrinsic_data = np.zeros(1, dtype=[('intrinsic', 'f4', (9,))])
    intrinsic_data['intrinsic'][0] = [focal_length_px, 0, img_w / 2, 0, focal_length_px, img_h / 2, 0, 0, 1]

    image_size_data = np.zeros(1, dtype=[('image_size', 'u4', (2,))])
    image_size_data['image_size'][0] = [img_w, img_h]

    frame_data = np.zeros(1, dtype=[('frame', 'i4', (2,))])
    frame_data['frame'][0] = [1, num_gaussians]

    # 10th / 90th percentile disparity (1/z, with z clamped to >= 1e-6),
    # picked by index into the sorted values.
    disparities = np.sort(1.0 / np.maximum(z_values, 1e-6))
    count = len(disparities)
    disparity_10 = disparities[int(count * 0.1)] if count > 0 else 0.0
    disparity_90 = disparities[int(count * 0.9)] if count > 0 else 1.0
    disparity_data = np.zeros(1, dtype=[('disparity', 'f4', (2,))])
    disparity_data['disparity'][0] = [disparity_10, disparity_90]

    # NOTE(review): presumably 1 tags the sRGB-encoded colours written to the
    # vertices; confirm against the consumer of this file.
    color_space_data = np.zeros(1, dtype=[('color_space', 'u1')])
    color_space_data['color_space'][0] = 1

    version_data = np.zeros(1, dtype=[('version', 'u1', (3,))])
    version_data['version'][0] = [1, 5, 0]

    return [
        PlyElement.describe(extrinsic_data, 'extrinsic'),
        PlyElement.describe(intrinsic_data, 'intrinsic'),
        PlyElement.describe(image_size_data, 'image_size'),
        PlyElement.describe(frame_data, 'frame'),
        PlyElement.describe(disparity_data, 'disparity'),
        PlyElement.describe(color_space_data, 'color_space'),
        PlyElement.describe(version_data, 'version'),
    ]


def export_ply(outputs: dict[str, np.ndarray], output_path: str | Path,
               focal_length_px: float, image_shape: tuple[int, int],
               decimation: float = 1.0, depth_scale: float = 1.0) -> None:
    """Export the first batch item's Gaussians to a binary PLY file.

    Args:
        outputs: Mapping with keys ``mean_vectors_3d_positions``,
            ``singular_values_scales``, ``quaternions_rotations``,
            ``colors_rgb_linear`` and ``opacities_alpha_channel``. Each entry is
            assumed to be batch-first with N Gaussians on axis 1: (1, N, 3),
            (1, N, 3), (1, N, 4), (1, N, 3) and (1, N) or (1, N, 1).
        output_path: Destination ``.ply`` path.
        focal_length_px: Focal length in pixels. It is written into the
            intrinsic matrix and used to map NDC x/y into camera space.
        image_shape: ``(height, width)`` of the image the focal length refers to.
        decimation: Fraction of Gaussians to keep. Values below 1.0 keep the
            highest scale-volume x opacity Gaussians, and at least one is kept.
        depth_scale: Factor stretching normalized depth around its median.
            Values above 1.0 exaggerate relief.

    Raises:
        ValueError: If there are no Gaussians, or the minimum depth is not
            strictly positive.
    """
    output_path = Path(output_path)
    img_h, img_w = image_shape

    def first_item(key: str) -> np.ndarray:
        # Batch item 0, promoted to at least float32 so FP16 model outputs do
        # not underflow the 1e-10 clamps and log/exp arithmetic below to 0/-inf.
        arr = np.asarray(outputs[key][0])
        return arr.astype(np.result_type(arr.dtype, np.float32), copy=False)

    mean_vectors = first_item("mean_vectors_3d_positions")
    singular_values = first_item("singular_values_scales")
    quaternions = first_item("quaternions_rotations")
    colors = first_item("colors_rgb_linear")
    # A split concatenated output gives opacities of shape (N, 1). Flatten so
    # they combine with per-Gaussian (N,) arrays instead of broadcasting to (N, N).
    opacities = first_item("opacities_alpha_channel").reshape(-1)

    num_gaussians = mean_vectors.shape[0]
    LOGGER.info(f"Exporting {num_gaussians} Gaussians to PLY")
    if num_gaussians == 0:
        raise ValueError("No Gaussians to export")

    if decimation < 1.0:
        # Importance = product of the three scales (computed as exp of summed logs) x opacity.
        log_scales = np.log(np.maximum(singular_values, 1e-10))
        importance = np.exp(np.sum(log_scales, axis=1)) * opacities
        keep_count = max(1, int(num_gaussians * decimation))
        # Keep the most important Gaussians, restored to their original order.
        keep_indices = np.sort(np.argsort(-importance)[:keep_count])
        LOGGER.info(f"Decimating: keeping {keep_count} of {num_gaussians} ({decimation * 100:.1f}%)")
        mean_vectors = mean_vectors[keep_indices]
        singular_values = singular_values[keep_indices]
        quaternions = quaternions[keep_indices]
        colors = colors[keep_indices]
        opacities = opacities[keep_indices]
        num_gaussians = keep_count

    vertex_data = np.zeros(num_gaussians, dtype=[
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4'),
        ('opacity', 'f4'),
        ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'),
        ('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4')
    ])

    # Assumed model convention (not verified here): positions are
    # [z*x_ndc, z*y_ndc, z] with x_ndc, y_ndc in [-1, 1] and a small,
    # scale-invariant z range. Depth is rescaled so the nearest Gaussian
    # sits at z = 1.0.
    z_raw = mean_vectors[:, 2]
    z_min = np.min(z_raw)
    if not z_min > 0:
        raise ValueError(f"Expected strictly positive depths, got minimum z = {z_min}")
    z_normalized = z_raw / z_min

    # depth_scale stretches depths around their median. Large factors can push
    # the nearest points toward or past zero; this is not guarded.
    if depth_scale != 1.0:
        z_median = np.median(z_normalized)
        z_normalized = z_median + (z_normalized - z_median) * depth_scale

    # focal_ndc = 2*f/w converts x_ndc to x/z. When f equals the image width
    # this is 2.0, i.e. a horizontal FOV of 2*atan(0.5) ~= 53 degrees (not 90).
    focal_ndc = 2.0 * focal_length_px / img_w
    # x/y share the depth normalization (1/z_min) and also leave NDC (1/focal_ndc).
    scale_factor = 1.0 / (z_min * focal_ndc)
    vertex_data['x'] = mean_vectors[:, 0] * scale_factor
    vertex_data['y'] = mean_vectors[:, 1] * scale_factor
    vertex_data['z'] = z_normalized
    LOGGER.info(
        f"Depth range: {z_raw.min():.3f} - {z_raw.max():.3f} -> normalized: "
        f"{z_normalized.min():.3f} - {z_normalized.max():.3f}"
    )

    # Linear RGB -> sRGB -> degree-0 SH coefficients, vectorized over all Gaussians.
    sh_dc = rgb_to_sh(_linear_to_srgb_array(colors))
    vertex_data['f_dc_0'] = sh_dc[:, 0]
    vertex_data['f_dc_1'] = sh_dc[:, 1]
    vertex_data['f_dc_2'] = sh_dc[:, 2]
    vertex_data['opacity'] = inverse_sigmoid(opacities)

    # Scales are stored as logs, rescaled to the transformed coordinate space.
    # NOTE(review): singular values describe each Gaussian's local (rotated)
    # axes. Scaling axes 0/1 by scale_factor and axis 2 by 1/z_min treats them
    # as camera x/y/z; confirm this is intended.
    vertex_data['scale_0'] = np.log(np.maximum(singular_values[:, 0] * scale_factor, 1e-10))
    vertex_data['scale_1'] = np.log(np.maximum(singular_values[:, 1] * scale_factor, 1e-10))
    vertex_data['scale_2'] = np.log(np.maximum(singular_values[:, 2] / z_min, 1e-10))
    for k in range(4):
        vertex_data[f'rot_{k}'] = quaternions[:, k]

    # NOTE(review): disparity metadata uses the raw model depth, not the
    # normalized z written to the vertices. Confirm which one the consumer expects.
    metadata = _ply_metadata_elements(focal_length_px, img_h, img_w, num_gaussians, mean_vectors[:, 2])
    PlyData([PlyElement.describe(vertex_data, 'vertex'), *metadata], text=False).write(str(output_path))
    LOGGER.info(f"Saved PLY with {num_gaussians} Gaussians to {output_path}")
def main():
    """Command-line entry point: preprocess an image, run the ONNX model, write a PLY."""
    parser = argparse.ArgumentParser(
        description="ONNX Inference for SHARP - Generate 3D Gaussians from an image"
    )
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="Path to ONNX model file")
    parser.add_argument("-i", "--input", type=str, required=True,
                        help="Path to input image")
    parser.add_argument("-o", "--output", type=str, required=True,
                        help="Path to output file (.ply)")
    parser.add_argument("-d", "--decimate", type=float, default=1.0,
                        help="Decimation ratio 0.0-1.0 (default: 1.0 = keep all)")
    parser.add_argument("--disparity-factor", type=float, default=1.0,
                        help="Disparity factor for depth conversion (default: 1.0)")
    parser.add_argument("--depth-scale", type=float, default=1.0,
                        help="Depth exaggeration factor (>1.0 increases 3D relief, default: 1.0)")
    args = parser.parse_args()

    # Enforce the documented ratio range. Otherwise e.g. "-d 50" (meant as
    # 50%) would silently keep everything.
    if not 0.0 <= args.decimate <= 1.0:
        parser.error(f"--decimate must be between 0.0 and 1.0, got {args.decimate}")

    # Preprocess image. preprocess_image reports the original size as PIL's
    # (width, height), while export_ply unpacks image_shape as (height, width).
    image, focal_length_px, (orig_w, orig_h) = preprocess_image(args.input)

    # Run inference
    outputs = run_inference(args.model, image, args.disparity_factor)

    # Export to PLY
    export_ply(outputs, args.output, focal_length_px, (orig_h, orig_w),
               args.decimate, args.depth_scale)


if __name__ == "__main__":
    main()