#!/usr/bin/env python3
"""ONNX Inference Script for SHARP Model.
Loads an ONNX model (fp32 or fp16), runs inference on an input image,
and exports the result as a PLY file.
Usage:
# Convert and validate FP16 model
python convert_onnx.py -o sharp_fp16.onnx -q fp16 --validate
# Run inference with FP16 model
python inference_onnx.py -m sharp_fp16.onnx -i test.png -o test.ply -d 0.5
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
import numpy as np
import onnxruntime as ort
from PIL import Image
from plyfile import PlyData, PlyElement
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
LOGGER = logging.getLogger(__name__)
DEFAULT_HEIGHT = 1536
DEFAULT_WIDTH = 1536
def linear_to_srgb(linear: float) -> float:
    """Convert one linear-light color component to sRGB (IEC 61966-2-1 transfer curve)."""
    # Below the threshold the curve is a plain linear gain; above it, a gamma segment.
    threshold = 0.0031308
    if linear > threshold:
        return 1.055 * linear ** (1.0 / 2.4) - 0.055
    return 12.92 * linear
def rgb_to_sh(rgb: float) -> float:
    """Convert a color value to a degree-0 spherical-harmonic (DC) coefficient."""
    # Y_0^0 basis coefficient: 1 / (2 * sqrt(pi)). The 0.5 offset centers the
    # color before dividing out the basis.
    sh_dc_basis = 1.0 / np.sqrt(4.0 * np.pi)
    centered = rgb - 0.5
    return centered / sh_dc_basis
def inverse_sigmoid(x: float) -> float:
x = np.clip(x, 1e-6, 1.0 - 1e-6)
return np.log(x / (1.0 - x))
def preprocess_image(image_path: str | Path, target_size: tuple[int, int] = (DEFAULT_HEIGHT, DEFAULT_WIDTH)):
    """Load and preprocess an image for ONNX inference.

    Args:
        image_path: Path to the input image.
        target_size: (height, width) the model expects.

    Returns:
        (img_np, focal_length_px, image_shape) where img_np is a float32
        NCHW batch in [0, 1], focal_length_px is a heuristic focal length
        (the original image width in pixels), and image_shape is the
        original image's (height, width) — matching how export_ply()
        unpacks it. (The previous version returned PIL's (width, height),
        which swapped the intrinsics for non-square images.)
    """
    image_path = Path(image_path)
    target_h, target_w = target_size
    img = Image.open(image_path)
    original_w, original_h = img.size  # PIL reports (width, height)
    focal_length_px = original_w  # heuristic: focal length = image width
    # Normalize to 3-channel RGB so grayscale ("L"), palette ("P"), and RGBA
    # inputs all work; a 2-D grayscale array previously crashed on shape[2].
    if img.mode != "RGB":
        img = img.convert("RGB")
    if img.size != (target_w, target_h):
        img = img.resize((target_w, target_h), Image.BILINEAR)
    img_np = np.asarray(img, dtype=np.float32) / 255.0
    # HWC -> CHW, then add the batch dimension: (1, 3, H, W).
    img_np = np.expand_dims(np.transpose(img_np, (2, 0, 1)), axis=0)
    LOGGER.info(f"Loaded image: {image_path}, original size: {(original_w, original_h)}")
    LOGGER.info(f"Preprocessed shape: {img_np.shape}, range: [{img_np.min():.4f}, {img_np.max():.4f}]")
    return img_np, float(focal_length_px), (original_h, original_w)
def run_inference(onnx_path: str | Path, image: np.ndarray, disparity_factor: float = 1.0) -> dict[str, np.ndarray]:
    """Run ONNX inference on the preprocessed image.

    Args:
        onnx_path: Path to the .onnx model file (fp32 or fp16).
        image: Preprocessed NCHW float32 batch from preprocess_image().
        disparity_factor: Scalar fed to the model's "disparity_factor" input.

    Returns:
        Dict mapping Gaussian attribute names to arrays. Handles three model
        layouts: one concatenated tensor, five separate tensors, or anything
        else (falls back to the model's own output names).
    """
    # Canonical attribute names with their channel widths in the concatenated
    # layout [positions | scales | quaternions | colors | opacity].
    # Defined once so the single-output and five-output branches cannot drift.
    gaussian_attrs = [
        ("mean_vectors_3d_positions", 3),
        ("singular_values_scales", 3),
        ("quaternions_rotations", 4),
        ("colors_rgb_linear", 3),
        ("opacities_alpha_channel", 1),
    ]
    onnx_path = Path(onnx_path)
    LOGGER.info(f"Loading ONNX model: {onnx_path}")
    # Suppress constant-folding warnings for FP16 ops. These warnings are
    # benign - FP16 Sqrt/Tile ops run correctly but can't be pre-folded.
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 3  # 0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal
    # CPUExecutionProvider for universal compatibility: works on all platforms
    # and handles large models with external data files.
    session = ort.InferenceSession(str(onnx_path), sess_options, providers=['CPUExecutionProvider'])
    LOGGER.info("Using CPUExecutionProvider for inference")
    input_names = [inp.name for inp in session.get_inputs()]
    output_names = [out.name for out in session.get_outputs()]
    LOGGER.info(f"Input names: {input_names}")
    LOGGER.info(f"Output names: {output_names}")
    inputs = {
        "image": image.astype(np.float32),
        "disparity_factor": np.array([disparity_factor], dtype=np.float32)
    }
    LOGGER.info("Running inference...")
    raw_outputs = session.run(None, inputs)
    outputs: dict[str, np.ndarray] = {}
    if len(raw_outputs) == 1:
        # Single concatenated (B, N, 14) tensor: slice it channel-wise.
        concat = raw_outputs[0]
        start = 0
        for name, size in gaussian_attrs:
            outputs[name] = concat[:, :, start:start + size]
            start += size
    elif len(raw_outputs) == len(gaussian_attrs):
        # One tensor per attribute, in canonical order.
        for (name, _), out in zip(gaussian_attrs, raw_outputs):
            outputs[name] = out
    else:
        # Unknown layout: fall back to the model's declared output names.
        outputs = dict(zip(output_names, raw_outputs))
    for name, arr in outputs.items():
        LOGGER.info(f"  {name}: shape {arr.shape}")
    return outputs
def export_ply(outputs: dict[str, np.ndarray], output_path: str | Path,
               focal_length_px: float, image_shape: tuple[int, int],
               decimation: float = 1.0, depth_scale: float = 1.0) -> None:
    """Export Gaussians to PLY file format.

    Args:
        outputs: Dict of batched Gaussian attributes from run_inference().
        output_path: Destination .ply path (written as binary PLY).
        focal_length_px: Focal length in pixels for the intrinsic matrix.
        image_shape: (height, width) of the source image.
        decimation: Fraction of Gaussians to keep (1.0 = keep all), ranked
            by opacity-weighted scale volume.
        depth_scale: Exaggerates depth differences around the median
            (>1.0 increases 3D relief).
    """
    output_path = Path(output_path)
    # Drop the batch dimension (batch size 1 assumed).
    mean_vectors = outputs["mean_vectors_3d_positions"][0]
    singular_values = outputs["singular_values_scales"][0]
    quaternions = outputs["quaternions_rotations"][0]
    colors = outputs["colors_rgb_linear"][0]
    # Flatten opacities to (N,). A (N, 1) opacity tensor would otherwise
    # broadcast `scale_product * opacities` to (N, N) below (breaking
    # decimation) and fail the 1-D 'opacity' column assignment.
    opacities = np.asarray(outputs["opacities_alpha_channel"][0]).reshape(-1)
    num_gaussians = mean_vectors.shape[0]
    LOGGER.info(f"Exporting {num_gaussians} Gaussians to PLY")
    if decimation < 1.0:
        # Importance = ellipsoid volume proxy (product of scales) * opacity.
        log_scales = np.log(np.maximum(singular_values, 1e-10))
        scale_product = np.exp(np.sum(log_scales, axis=1))
        importance = scale_product * opacities
        indices = np.argsort(-importance)
        keep_count = max(1, int(num_gaussians * decimation))
        keep_indices = indices[:keep_count]
        keep_indices.sort()  # preserve original ordering of survivors
        LOGGER.info(f"Decimating: keeping {keep_count} of {num_gaussians} ({decimation * 100:.1f}%)")
        mean_vectors = mean_vectors[keep_indices]
        singular_values = singular_values[keep_indices]
        quaternions = quaternions[keep_indices]
        colors = colors[keep_indices]
        opacities = opacities[keep_indices]
        num_gaussians = keep_count
    vertex_data = np.zeros(num_gaussians, dtype=[
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4'),
        ('opacity', 'f4'),
        ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'),
        ('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4')
    ])
    # Model outputs [z*x_ndc, z*y_ndc, z] where z is normalized depth and
    # x_ndc, y_ndc ∈ [-1, 1]. The model's depth is scale-invariant and
    # normalized to a small range (typically ~0.5-0.7). We need to:
    # 1. Expand the depth range for proper 3D relief
    # 2. Convert projective coords to camera space: x_cam = (z*x_ndc) / focal_ndc
    img_h, img_w = image_shape
    z_raw = mean_vectors[:, 2]
    # Normalize depth so the nearest Gaussian sits at depth 1.0.
    z_min = np.min(z_raw)
    z_normalized = z_raw / z_min
    # depth_scale > 1.0 exaggerates depth differences around the median
    # (useful for flat scenes).
    if depth_scale != 1.0:
        z_median = np.median(z_normalized)
        z_normalized = z_median + (z_normalized - z_median) * depth_scale
    # NDC -> camera-space conversion. For focal length f and image width w:
    # focal_ndc = 2*f/w (with f = w, i.e. a 90° FOV assumption, focal_ndc = 2).
    focal_ndc = 2.0 * focal_length_px / img_w
    # The projective x/y values need the same depth normalization as z.
    scale_factor = 1.0 / (z_min * focal_ndc)
    vertex_data['x'] = mean_vectors[:, 0] * scale_factor
    vertex_data['y'] = mean_vectors[:, 1] * scale_factor
    vertex_data['z'] = z_normalized
    LOGGER.info(f"Depth range: {z_raw.min():.3f} - {z_raw.max():.3f} -> normalized: 1.0 - {z_normalized.max():.3f}")
    # Linear RGB -> sRGB -> degree-0 SH coefficients, vectorized over all
    # Gaussians (replaces a per-Gaussian Python loop). np.maximum inside the
    # gamma branch keeps np.power away from sub-threshold values it would
    # warn/NaN on; np.where still selects the linear branch for those entries.
    srgb = np.where(
        colors <= 0.0031308,
        colors * 12.92,
        1.055 * np.power(np.maximum(colors, 0.0031308), 1.0 / 2.4) - 0.055,
    )
    sh_dc = (srgb - 0.5) / (1.0 / np.sqrt(4.0 * np.pi))
    vertex_data['f_dc_0'] = sh_dc[:, 0]
    vertex_data['f_dc_1'] = sh_dc[:, 1]
    vertex_data['f_dc_2'] = sh_dc[:, 2]
    vertex_data['opacity'] = inverse_sigmoid(opacities)
    # Scale the Gaussian sizes to match the transformed coordinate space.
    vertex_data['scale_0'] = np.log(np.maximum(singular_values[:, 0] * scale_factor, 1e-10))
    vertex_data['scale_1'] = np.log(np.maximum(singular_values[:, 1] * scale_factor, 1e-10))
    vertex_data['scale_2'] = np.log(np.maximum(singular_values[:, 2] / z_min, 1e-10))  # Z scale uses depth normalization
    vertex_data['rot_0'] = quaternions[:, 0]
    vertex_data['rot_1'] = quaternions[:, 1]
    vertex_data['rot_2'] = quaternions[:, 2]
    vertex_data['rot_3'] = quaternions[:, 3]
    vertex_element = PlyElement.describe(vertex_data, 'vertex')
    # Extrinsic: 4x4 identity matrix (row-major) as 16 float properties.
    extrinsic_data = np.zeros(1, dtype=[('extrinsic', 'f4', (16,))])
    extrinsic_data['extrinsic'][0] = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]
    extrinsic_element = PlyElement.describe(extrinsic_data, 'extrinsic')
    # Intrinsic: 3x3 pinhole matrix with the principal point at image center.
    intrinsic_data = np.zeros(1, dtype=[('intrinsic', 'f4', (9,))])
    intrinsic_data['intrinsic'][0] = [focal_length_px, 0, img_w / 2, 0, focal_length_px, img_h / 2, 0, 0, 1]
    intrinsic_element = PlyElement.describe(intrinsic_data, 'intrinsic')
    # Image size: (width, height) as uint32.
    image_size_data = np.zeros(1, dtype=[('image_size', 'u4', (2,))])
    image_size_data['image_size'][0] = [img_w, img_h]
    image_size_element = PlyElement.describe(image_size_data, 'image_size')
    # Frame: [frame count, Gaussians per frame] as int32.
    frame_data = np.zeros(1, dtype=[('frame', 'i4', (2,))])
    frame_data['frame'][0] = [1, num_gaussians]
    frame_element = PlyElement.describe(frame_data, 'frame')
    # Disparity: 10th/90th percentiles of 1/z over the raw model depths.
    z_safe = np.maximum(mean_vectors[:, 2], 1e-6)
    disparities = np.sort(1.0 / z_safe)
    disparity_10 = disparities[int(len(disparities) * 0.1)] if len(disparities) > 0 else 0.0
    disparity_90 = disparities[int(len(disparities) * 0.9)] if len(disparities) > 0 else 1.0
    disparity_data = np.zeros(1, dtype=[('disparity', 'f4', (2,))])
    disparity_data['disparity'][0] = [disparity_10, disparity_90]
    disparity_element = PlyElement.describe(disparity_data, 'disparity')
    # Color space: single uchar flag (value 1; presumably marks sRGB-encoded
    # colors — NOTE(review): confirm against the PLY consumer).
    color_space_data = np.zeros(1, dtype=[('color_space', 'u1')])
    color_space_data['color_space'][0] = 1
    color_space_element = PlyElement.describe(color_space_data, 'color_space')
    # Version: 3 uchar properties (here 1.5.0).
    version_data = np.zeros(1, dtype=[('version', 'u1', (3,))])
    version_data['version'][0] = [1, 5, 0]
    version_element = PlyElement.describe(version_data, 'version')
    PlyData([
        vertex_element,
        extrinsic_element,
        intrinsic_element,
        image_size_element,
        frame_element,
        disparity_element,
        color_space_element,
        version_element
    ], text=False).write(str(output_path))
    LOGGER.info(f"Saved PLY with {num_gaussians} Gaussians to {output_path}")
def main():
    """CLI entry point: load image, run ONNX inference, export PLY."""
    parser = argparse.ArgumentParser(
        description="ONNX Inference for SHARP - Generate 3D Gaussians from an image"
    )
    # Declarative flag table keeps each option on one line.
    flag_specs = [
        (("-m", "--model"), dict(type=str, required=True, help="Path to ONNX model file")),
        (("-i", "--input"), dict(type=str, required=True, help="Path to input image")),
        (("-o", "--output"), dict(type=str, required=True, help="Path to output file (.ply)")),
        (("-d", "--decimate"), dict(type=float, default=1.0, help="Decimation ratio 0.0-1.0 (default: 1.0 = keep all)")),
        (("--disparity-factor",), dict(type=float, default=1.0, help="Disparity factor for depth conversion (default: 1.0)")),
        (("--depth-scale",), dict(type=float, default=1.0, help="Depth exaggeration factor (>1.0 increases 3D relief, default: 1.0)")),
    ]
    for flags, kwargs in flag_specs:
        parser.add_argument(*flags, **kwargs)
    args = parser.parse_args()
    # Pipeline: preprocess -> infer -> export.
    image, focal_length_px, image_shape = preprocess_image(args.input)
    gaussians = run_inference(args.model, image, args.disparity_factor)
    export_ply(gaussians, args.output, focal_length_px, image_shape, args.decimate, args.depth_scale)
if __name__ == "__main__":
    main()