Kyle Pearson
commited on
Commit
·
2cda5f8
1
Parent(s):
9bef2af
inference script
Browse files- inference_onnx.py +302 -0
inference_onnx.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""ONNX Inference Script for SHARP Model.
|
| 3 |
+
|
| 4 |
+
Loads an ONNX model (fp32 or fp16), runs inference on an input image,
|
| 5 |
+
and exports the result as a PLY file.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python inference_onnx.py -m sharp.onnx -i test.png -o output.ply
|
| 9 |
+
python inference_onnx.py -m sharp_inline_fp16.onnx -i test.png -o output.ply -d 0.5
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import logging
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import onnxruntime as ort
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from plyfile import PlyData, PlyElement
|
| 22 |
+
|
| 23 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 24 |
+
LOGGER = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
DEFAULT_HEIGHT = 1536
|
| 27 |
+
DEFAULT_WIDTH = 1536
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def linear_to_srgb(linear: float) -> float:
    """Convert a linear-light color component to sRGB (IEC 61966-2-1 transfer curve)."""
    threshold = 0.0031308
    if linear > threshold:
        # Gamma segment of the sRGB transfer function.
        return 1.055 * linear ** (1.0 / 2.4) - 0.055
    # Linear segment near black.
    return linear * 12.92
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def rgb_to_sh(rgb: float) -> float:
    """Map a color value in [0, 1] to a degree-0 spherical-harmonic (DC) coefficient."""
    # Y_0^0 basis constant.
    coeff_degree0 = 1.0 / np.sqrt(4.0 * np.pi)
    shifted = rgb - 0.5
    return shifted / coeff_degree0
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def inverse_sigmoid(x: float) -> float:
    """Logit function: inverse of the sigmoid, clipped to avoid +/-inf at 0 and 1."""
    clipped = np.clip(x, 1e-6, 1.0 - 1e-6)
    return np.log(clipped / (1.0 - clipped))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def preprocess_image(image_path: str | Path, target_size: tuple[int, int] = (DEFAULT_HEIGHT, DEFAULT_WIDTH)):
    """Load and preprocess an image for ONNX inference.

    Args:
        image_path: Path to the input image.
        target_size: (height, width) that the model expects.

    Returns:
        Tuple of:
        - image tensor of shape (1, 3, target_h, target_w), float32 in [0, 1]
        - focal length in pixels (heuristic: the original image width)
        - original image size as (height, width), matching the
          ``img_h, img_w = image_shape`` unpacking done by ``export_ply``
    """
    image_path = Path(image_path)
    target_h, target_w = target_size

    img = Image.open(image_path)
    original_w, original_h = img.size  # PIL reports size as (width, height)
    focal_length_px = original_w

    # Normalize to 3-channel RGB up front. This also handles grayscale ("L"),
    # palette ("P"), and RGBA inputs (alpha is dropped), where the previous
    # `img_np.shape[2]` check crashed on 2-D grayscale arrays.
    img = img.convert("RGB")

    if img.size != (target_w, target_h):
        img = img.resize((target_w, target_h), Image.BILINEAR)

    img_np = np.array(img, dtype=np.float32) / 255.0
    img_np = np.transpose(img_np, (2, 0, 1))  # HWC -> CHW
    img_np = np.expand_dims(img_np, axis=0)   # add batch dimension

    LOGGER.info(f"Loaded image: {image_path}, original size: {(original_w, original_h)}")
    LOGGER.info(f"Preprocessed shape: {img_np.shape}, range: [{img_np.min():.4f}, {img_np.max():.4f}]")

    # BUG FIX: callers unpack the third return value as (height, width), but
    # PIL's `img.size` is (width, height) — return the swapped order so the
    # intrinsic matrix in export_ply is correct for non-square images.
    return img_np, float(focal_length_px), (original_h, original_w)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def run_inference(onnx_path: str | Path, image: np.ndarray, disparity_factor: float = 1.0) -> dict[str, np.ndarray]:
    """Run ONNX inference on the preprocessed image.

    Args:
        onnx_path: Path to the ONNX model file (fp32 or fp16).
        image: Preprocessed image tensor of shape (1, 3, H, W), float32.
        disparity_factor: Scalar fed to the model's "disparity_factor" input.

    Returns:
        Mapping of Gaussian attribute names to arrays (each with a leading
        batch dimension).

    Raises:
        RuntimeError: If the model is an incorrectly converted mixed-precision
            FP16 model, or cannot be loaded on any execution provider.
    """
    onnx_path = Path(onnx_path)
    LOGGER.info(f"Loading ONNX model: {onnx_path}")

    # Try with default providers first, then fall back to CPU only.
    try:
        session = ort.InferenceSession(str(onnx_path))
    except Exception as e:
        error_msg = str(e)
        if "tensor(float16)" in error_msg and "tensor(float)" in error_msg:
            LOGGER.error("FP16 model has mixed float16/float32 types. This model was converted incorrectly.")
            LOGGER.error("For FP16 inference on Apple Silicon, use the Core ML model (sharp.mlpackage) instead.")
            LOGGER.error("Or regenerate the ONNX model with proper FP16 conversion.")
            # Chain the cause so the original provider error is preserved.
            raise RuntimeError(f"Invalid FP16 model: {error_msg}") from e
        # Try CPU fallback.
        try:
            session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
        except Exception as cpu_e:
            raise RuntimeError(f"Failed to load ONNX model: {cpu_e}") from cpu_e

    input_names = [inp.name for inp in session.get_inputs()]
    output_names = [out.name for out in session.get_outputs()]

    LOGGER.info(f"Input names: {input_names}")
    LOGGER.info(f"Output names: {output_names}")

    inputs = {
        "image": image.astype(np.float32),
        "disparity_factor": np.array([disparity_factor], dtype=np.float32)
    }

    LOGGER.info("Running inference...")
    raw_outputs = session.run(None, inputs)

    # Canonical attribute names, in the model's channel order (previously
    # duplicated in two branches).
    names = [
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel"
    ]

    outputs: dict[str, np.ndarray] = {}

    if len(raw_outputs) == 1:
        # Single concatenated tensor: split the channel axis into attributes.
        concat = raw_outputs[0]
        sizes = [3, 3, 4, 3, 1]
        start = 0
        for name, size in zip(names, sizes):
            outputs[name] = concat[:, :, start:start + size]
            start += size
    elif len(raw_outputs) == 5:
        # One tensor per attribute, in canonical order.
        for name, out in zip(names, raw_outputs):
            outputs[name] = out
    else:
        # Unknown layout: fall back to the model's own output names.
        for name, out in zip(output_names, raw_outputs):
            outputs[name] = out

    for name, arr in outputs.items():
        LOGGER.info(f"  {name}: shape {arr.shape}")

    return outputs
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def export_ply(outputs: dict[str, np.ndarray], output_path: str | Path,
               focal_length_px: float, image_shape: tuple[int, int],
               decimation: float = 1.0) -> None:
    """Export Gaussians to PLY file format.

    Args:
        outputs: Model outputs keyed by attribute name, each with a leading
            batch dimension.
        output_path: Destination .ply path.
        focal_length_px: Focal length in pixels for the intrinsic matrix.
        image_shape: (height, width) of the source image.
        decimation: Fraction of Gaussians to keep, ranked by importance
            (scale volume x opacity). 1.0 keeps all.
    """
    output_path = Path(output_path)

    # Drop the batch dimension from every attribute.
    mean_vectors = outputs["mean_vectors_3d_positions"][0]
    singular_values = outputs["singular_values_scales"][0]
    quaternions = outputs["quaternions_rotations"][0]
    colors = outputs["colors_rgb_linear"][0]
    # BUG FIX: opacities arrive as (N, 1). Flatten to (N,) so the importance
    # product below stays 1-D — (N,) * (N, 1) would broadcast to an (N, N)
    # matrix — and so the structured-array assignment to the scalar
    # 'opacity' field does not fail to broadcast.
    opacities = np.asarray(outputs["opacities_alpha_channel"][0]).reshape(-1)

    num_gaussians = mean_vectors.shape[0]
    LOGGER.info(f"Exporting {num_gaussians} Gaussians to PLY")

    if decimation < 1.0:
        # Importance = product of the three scales (a volume proxy) x opacity,
        # computed in log space for numerical stability.
        log_scales = np.log(np.maximum(singular_values, 1e-10))
        scale_product = np.exp(np.sum(log_scales, axis=1))
        importance = scale_product * opacities

        indices = np.argsort(-importance)
        keep_count = max(1, int(num_gaussians * decimation))
        keep_indices = indices[:keep_count]
        keep_indices.sort()  # restore original ordering of survivors

        LOGGER.info(f"Decimating: keeping {keep_count} of {num_gaussians} ({decimation * 100:.1f}%)")

        mean_vectors = mean_vectors[keep_indices]
        singular_values = singular_values[keep_indices]
        quaternions = quaternions[keep_indices]
        colors = colors[keep_indices]
        opacities = opacities[keep_indices]
        num_gaussians = keep_count

    vertex_data = np.zeros(num_gaussians, dtype=[
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4'),
        ('opacity', 'f4'),
        ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'),
        ('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4')
    ])

    vertex_data['x'] = mean_vectors[:, 0]
    vertex_data['y'] = mean_vectors[:, 1]
    vertex_data['z'] = mean_vectors[:, 2]

    # Vectorized linear->sRGB conversion (replaces a per-Gaussian Python
    # loop, which was interpreter-bound for large N). np.maximum keeps
    # np.power away from values <= threshold; those lanes are discarded
    # by np.where anyway, so results are unchanged.
    srgb = np.where(
        colors <= 0.0031308,
        colors * 12.92,
        1.055 * np.power(np.maximum(colors, 0.0031308), 1.0 / 2.4) - 0.055,
    )
    sh = rgb_to_sh(srgb)
    vertex_data['f_dc_0'] = sh[:, 0]
    vertex_data['f_dc_1'] = sh[:, 1]
    vertex_data['f_dc_2'] = sh[:, 2]

    vertex_data['opacity'] = inverse_sigmoid(opacities)

    vertex_data['scale_0'] = np.log(np.maximum(singular_values[:, 0], 1e-10))
    vertex_data['scale_1'] = np.log(np.maximum(singular_values[:, 1], 1e-10))
    vertex_data['scale_2'] = np.log(np.maximum(singular_values[:, 2], 1e-10))

    vertex_data['rot_0'] = quaternions[:, 0]
    vertex_data['rot_1'] = quaternions[:, 1]
    vertex_data['rot_2'] = quaternions[:, 2]
    vertex_data['rot_3'] = quaternions[:, 3]

    vertex_element = PlyElement.describe(vertex_data, 'vertex')

    # Extrinsic: 4x4 identity matrix as 16 separate properties
    extrinsic_data = np.zeros(1, dtype=[('extrinsic', 'f4', (16,))])
    extrinsic_data['extrinsic'][0] = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]
    extrinsic_element = PlyElement.describe(extrinsic_data, 'extrinsic')

    img_h, img_w = image_shape
    # Intrinsic: 3x3 matrix as 9 separate properties
    intrinsic_data = np.zeros(1, dtype=[('intrinsic', 'f4', (9,))])
    intrinsic_data['intrinsic'][0] = [focal_length_px, 0, img_w / 2, 0, focal_length_px, img_h / 2, 0, 0, 1]
    intrinsic_element = PlyElement.describe(intrinsic_data, 'intrinsic')

    # Image size: 2 separate uint32 properties
    image_size_data = np.zeros(1, dtype=[('image_size', 'u4', (2,))])
    image_size_data['image_size'][0] = [img_w, img_h]
    image_size_element = PlyElement.describe(image_size_data, 'image_size')

    # Frame: 2 separate int32 properties
    frame_data = np.zeros(1, dtype=[('frame', 'i4', (2,))])
    frame_data['frame'][0] = [1, num_gaussians]
    frame_element = PlyElement.describe(frame_data, 'frame')

    # 10th/90th-percentile disparities (1/z) for the viewer's depth range.
    z_values = mean_vectors[:, 2]
    z_safe = np.maximum(z_values, 1e-6)
    disparities = 1.0 / z_safe
    disparities.sort()
    disparity_10 = disparities[int(len(disparities) * 0.1)] if len(disparities) > 0 else 0.0
    disparity_90 = disparities[int(len(disparities) * 0.9)] if len(disparities) > 0 else 1.0
    disparity_data = np.zeros(1, dtype=[('disparity', 'f4', (2,))])
    disparity_data['disparity'][0] = [disparity_10, disparity_90]
    disparity_element = PlyElement.describe(disparity_data, 'disparity')

    # Color space: single uchar property
    color_space_data = np.zeros(1, dtype=[('color_space', 'u1')])
    color_space_data['color_space'][0] = 1
    color_space_element = PlyElement.describe(color_space_data, 'color_space')

    # Version: 3 uchar properties
    version_data = np.zeros(1, dtype=[('version', 'u1', (3,))])
    version_data['version'][0] = [1, 5, 0]
    version_element = PlyElement.describe(version_data, 'version')

    PlyData([
        vertex_element,
        extrinsic_element,
        intrinsic_element,
        image_size_element,
        frame_element,
        disparity_element,
        color_space_element,
        version_element
    ], text=False).write(str(output_path))

    LOGGER.info(f"Saved PLY with {num_gaussians} Gaussians to {output_path}")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def main():
    """CLI entry point: preprocess an image, run ONNX inference, export a PLY."""
    arg_parser = argparse.ArgumentParser(
        description="ONNX Inference for SHARP - Generate 3D Gaussians from an image"
    )
    arg_parser.add_argument("-m", "--model", type=str, required=True,
                            help="Path to ONNX model file")
    arg_parser.add_argument("-i", "--input", type=str, required=True,
                            help="Path to input image")
    arg_parser.add_argument("-o", "--output", type=str, required=True,
                            help="Path to output file (.ply)")
    arg_parser.add_argument("-d", "--decimate", type=float, default=1.0,
                            help="Decimation ratio 0.0-1.0 (default: 1.0 = keep all)")
    arg_parser.add_argument("--disparity-factor", type=float, default=1.0,
                            help="Disparity factor for depth conversion (default: 1.0)")

    opts = arg_parser.parse_args()

    # Image -> tensor, then model inference, then PLY on disk.
    tensor, focal_px, size_info = preprocess_image(opts.input)
    gaussians = run_inference(opts.model, tensor, opts.disparity_factor)
    export_ply(gaussians, opts.output, focal_px, size_info, opts.decimate)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|